aleo1 commited on
Commit
9f3352f
1 Parent(s): 4072e7b

Upload 23 files

Browse files
cisen/config/cisen_r0.9_fpn.yaml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATA:
2
+ dataset: classification
3
+ dataset_json_file: /data02/xy/dataEngine/json_data/LuojiaHOG(test)_.json
4
+ # dataset_json_file: /data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json
5
+ # dataset_json_file: /data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json
6
+ exp_name: classifi
7
+ ratio: 0
8
+ dataset_train_split: 0.6
9
+ dataset_query_split: 0.2
10
+ imgs_folder: /data02/xy/Clip-hash/datasets/image/
11
+ label_path: /data02/xy/Clip-hash/labels.txt
12
+ num_classes: 10
13
+ # num_classes: 131
14
+ TRAIN:
15
+ # Base Arch
16
+ # clip_pretrain: /data02/xy/Clip-hash/pretrain/RS5M_ViT-B-32.pt
17
+ clip_pretrain: ./cisen/pretrain/RS5M_ViT-B-32.pt
18
+ model_name: ViT-B-32
19
+ ckpt_path: /data02/xy/GeoRSCLIP/codebase/inference/pretrain/RS5M_ViT-B-32.pt
20
+ input_size: 224
21
+ word_len: 328
22
+ word_dim: 1024
23
+ vis_dim: 512
24
+ fpn_in: [ 512, 768, 768 ]
25
+ fpn_out: [ 768, 768, 768, 512 ]
26
+ sync_bn: True
27
+ # Decoder
28
+ num_layers: 3
29
+ num_head: 8
30
+ dim_ffn: 2048
31
+ dropout: 0.1
32
+ intermediate: False
33
+ # Training Setting
34
+ workers: 32 # data loader workers
35
+ workers_val: 16
36
+ epochs: 50
37
+ milestones: [50]
38
+ start_epoch: 0
39
+ batch_size: 256 # batch size for training
40
+ batch_size_val: 256 # batch size for validation during training, memory and speed tradeoff 11111
41
+ base_lr: 0.0001
42
+ min_lr: 0.00000001
43
+ lr_decay: 0.5
44
+ lr_multi: 0.1
45
+ weight_decay: 0.
46
+ max_norm: 0.
47
+ manual_seed: 0
48
+ print_freq: 1
49
+ lamda1: 0.5
50
+ lamda2: 0.5
51
+ beta1: 0.5
52
+ beta2: 0.5
53
+ eta: 0.2
54
+ warmup_epochs: 0
55
+ contrastive: [0.4, 0.3, 0.3]
56
+ # Resume & Save
57
+
58
+ output_folder: /data02/xy/Clip-hash/exp/
59
+ save_freq: 1
60
+ weight: # path to initial weight (default: none)
61
+ resume: False # path to latest checkpoint (default: none)
62
+ evaluate: True # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommend
63
+ Distributed:
64
+ dist_url: tcp://localhost:3693
65
+ dist_backend: 'nccl'
66
+ multiprocessing_distributed: True
67
+ world_size: 1
68
+ rank: 0
69
+ TEST:
70
+ test_split: val-test
71
+ gpu : [0]
72
+ test_lmdb: /data02/xy/Clip-hash/datasets/lmdb/refcoco/val.lmdb
73
+ visualize: False
74
+ topk: 5
75
+ test_batch_size: 256 #1111111
76
+ val_batch_size: 1
cisen/engine/__init__.py ADDED
File without changes
cisen/engine/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (133 Bytes). View file
 
cisen/engine/__pycache__/engine.cpython-38.pyc ADDED
Binary file (7.95 kB). View file
 
cisen/engine/demo.py ADDED
File without changes
cisen/engine/engine.py ADDED
The diff for this file is too large to render. See raw diff
 
cisen/model/__init__.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .segmenter import CRIS, CISEN, Clip_hash_model, zh_clip, poi_clip, Clip_model, CISEN_vit, CISEN_rsvit, CISEN_new, CISEN_rsvit_classification, CISEN_lclip
2
+ from .segmenter import *
3
+ from loguru import logger
4
+ from transformers import AlignProcessor, AlignModel
5
+ # def build_segmenter(args):
6
+ # model = CRIS(args)
7
+ # backbone = []
8
+ # backbone_no_decay = []
9
+ # head = []
10
+ # for k, v in model.named_parameters():
11
+ # if k.startswith('backbone') and 'positional_embedding' not in k:
12
+ # backbone.append(v)
13
+ # elif 'positional_embedding' in k:
14
+ # backbone_no_decay.append(v)
15
+ # else:
16
+ # head.append(v)
17
+ # print('Backbone with decay: {}, Backbone without decay: {}, Head: {}'.format(
18
+ # len(backbone), len(backbone_no_decay), len(head)))
19
+ # param_list = [{
20
+ # 'params': backbone,
21
+ # 'initial_lr': args.lr_multi * args.base_lr
22
+ # }, {
23
+ # 'params': backbone_no_decay,
24
+ # 'initial_lr': args.lr_multi * args.base_lr,
25
+ # 'weight_decay': 0
26
+ # }, {
27
+ # 'params': head,
28
+ # 'initial_lr': args.base_lr
29
+ # }]
30
+ # return model, param_list
31
+
32
+
33
+
34
+ def build_CISEN(args, stage):
35
+ model = CISEN_new(args)
36
+ backbone = []
37
+ head = []
38
+ ADP = []
39
+ ADP_t = []
40
+ fuse = []
41
+ name = []
42
+ for k, v in model.named_parameters():
43
+ if k.startswith('backbone') and 'backbone.positional_embedding' not in k:
44
+ # if k.startswith('backbone'):
45
+ v.requires_grad = False
46
+ backbone.append(v)
47
+ elif k.startswith('ADP'):
48
+ # v.requires_grad = False
49
+ ADP.append(v)
50
+ elif k.startswith('FPN'):
51
+ fuse.append(v)
52
+ elif k.startswith('gap'):
53
+ fuse.append(v)
54
+ elif k.startswith('ADP_t'):
55
+ ADP_t.append(v)
56
+ else:
57
+ head.append(v)
58
+ name.append(k)
59
+ # logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
60
+ # param_list = [{
61
+ # 'params': backbone,
62
+ # 'initial_lr': args.lr_multi * float(args.base_lr)
63
+ # }, {
64
+ # 'params': head,
65
+ # 'initial_lr': args.base_lr
66
+ # }, {
67
+ # 'params': proj,
68
+ # 'initial_lr': args.base_lr
69
+ # }]
70
+ if stage == '1st':
71
+ param_list = [{
72
+ 'params': ADP,
73
+ 'initial_lr': args.base_lr
74
+ },{
75
+ 'params': head,
76
+ 'initial_lr': args.base_lr
77
+ }]
78
+ elif stage == '2nd':
79
+ param_list = [{
80
+ 'params': fuse,
81
+ 'initial_lr': args.base_lr
82
+ }]
83
+ elif stage == '4th':
84
+ param_list = [{
85
+ 'params': fuse,
86
+ 'initial_lr': args.base_lr
87
+ }]
88
+ elif stage == '5th':
89
+ param_list = [{
90
+ # 'params': ADP,
91
+ # 'initial_lr': args.base_lr
92
+ # },{
93
+ # 'params': ADP_t,
94
+ # 'initial_lr': args.base_lr
95
+ # },{
96
+ 'params': fuse,
97
+ 'initial_lr': args.base_lr
98
+ }]
99
+ else:
100
+ print('stage should be either 1st or 2nd')
101
+ return model, param_list
102
+
103
+ def build_CISEN_lclip(args, stage):
104
+ model = CISEN_lclip(args)
105
+ backbone = []
106
+ head = []
107
+ ADP = []
108
+ ADP_t = []
109
+ fuse = []
110
+ name = []
111
+ for k, v in model.named_parameters():
112
+ # if k.startswith('backbone') and 'backbone.positional_embedding' not in k:
113
+ if k.startswith('backbone'):
114
+ v.requires_grad = False
115
+ backbone.append(v)
116
+ elif k.startswith('ADP'):
117
+ # v.requires_grad = False
118
+ ADP.append(v)
119
+ elif k.startswith('FPN'):
120
+ fuse.append(v)
121
+ elif k.startswith('gap'):
122
+ fuse.append(v)
123
+ elif k.startswith('ADP_t'):
124
+ ADP_t.append(v)
125
+ else:
126
+ head.append(v)
127
+ name.append(k)
128
+ # logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
129
+ # param_list = [{
130
+ # 'params': backbone,
131
+ # 'initial_lr': args.lr_multi * float(args.base_lr)
132
+ # }, {
133
+ # 'params': head,
134
+ # 'initial_lr': args.base_lr
135
+ # }, {
136
+ # 'params': proj,
137
+ # 'initial_lr': args.base_lr
138
+ # }]
139
+ if stage == '1st':
140
+ param_list = [{
141
+ 'params': ADP,
142
+ 'initial_lr': args.base_lr
143
+ },{
144
+ 'params': head,
145
+ 'initial_lr': args.base_lr
146
+ }]
147
+ elif stage == '2nd':
148
+ param_list = [{
149
+ 'params': fuse,
150
+ 'initial_lr': args.base_lr
151
+ }]
152
+ elif stage == '4th':
153
+ param_list = [{
154
+ 'params': fuse,
155
+ 'initial_lr': args.base_lr
156
+ }]
157
+ elif stage == '5th':
158
+ param_list = [{
159
+ # 'params': ADP,
160
+ # 'initial_lr': args.base_lr
161
+ # },{
162
+ # 'params': ADP_t,
163
+ # 'initial_lr': args.base_lr
164
+ # },{
165
+ 'params': fuse,
166
+ 'initial_lr': args.base_lr
167
+ }]
168
+ else:
169
+ print('stage should be either 1st or 2nd')
170
+ return model, param_list
171
+
172
+ def build_CISEN_vit(args, stage):
173
+ model = CISEN_rsvit(args)
174
+ backbone = []
175
+ head = []
176
+ ADP = []
177
+ ADP_t = []
178
+ fuse = []
179
+ name = []
180
+ for k, v in model.named_parameters():
181
+ # if k.startswith('backbone') and 'backbone.positional_embedding' not in k:
182
+ if k.startswith('backbone'):
183
+ v.requires_grad = False
184
+ backbone.append(v)
185
+ elif k.startswith('ADP'):
186
+ v.requires_grad = False
187
+ ADP.append(v)
188
+ elif k.startswith('FPN'):
189
+ # v.requires_grad = False
190
+ fuse.append(v)
191
+ elif k.startswith('ms_adaptor'):
192
+ # v.requires_grad = False
193
+ fuse.append(v)
194
+ else:
195
+ head.append(v)
196
+ name.append(k)
197
+ # logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
198
+ # param_list = [{
199
+ # 'params': backbone,
200
+ # 'initial_lr': args.lr_multi * float(args.base_lr)
201
+ # }, {
202
+ # 'params': head,
203
+ # 'initial_lr': args.base_lr
204
+ # }, {
205
+ # 'params': proj,
206
+ # 'initial_lr': args.base_lr
207
+ # }]
208
+ if stage == '1st':
209
+ param_list = [{
210
+ 'params': ADP,
211
+ 'initial_lr': args.base_lr
212
+ },{
213
+ 'params': head,
214
+ 'initial_lr': args.base_lr
215
+ }]
216
+ elif stage == '2nd':
217
+ param_list = [{
218
+ 'params': fuse,
219
+ 'initial_lr': args.base_lr
220
+ }]
221
+ elif stage == '4th':
222
+ param_list = [{
223
+ 'params': fuse,
224
+ 'initial_lr': args.base_lr
225
+ }]
226
+ elif stage == '5th':
227
+ param_list = [{
228
+ # 'params': ADP,
229
+ # 'initial_lr': args.base_lr
230
+ # },{
231
+ # 'params': ADP_t,
232
+ # 'initial_lr': args.base_lr
233
+ # },{
234
+ 'params': fuse,
235
+ 'initial_lr': args.base_lr
236
+ }]
237
+ else:
238
+ print('stage should be either 1st or 2nd')
239
+ return model, param_list
240
+
241
+ def build_CISEN_vit_classification(args, stage):
242
+ model = CISEN_rsvit_classification(args)
243
+
244
+ # logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
245
+ # param_list = [{
246
+ # 'params': backbone,
247
+ # 'initial_lr': args.lr_multi * float(args.base_lr)
248
+ # }, {
249
+ # 'params': head,
250
+ # 'initial_lr': args.base_lr
251
+ # }, {
252
+ # 'params': proj,
253
+ # 'initial_lr': args.base_lr
254
+ # }]
255
+
256
+ return model
257
+
258
+ def build_segmenter(args):
259
+ model = CRIS(args)
260
+ backbone = []
261
+ head = []
262
+ for k, v in model.named_parameters():
263
+ if k.startswith('backbone') and 'positional_embedding' not in k:
264
+ backbone.append(v)
265
+ elif k.startswith('Label_encoder') and "token_embedding" not in k:
266
+ v.requires_grad = False
267
+ else:
268
+ head.append(v)
269
+
270
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
271
+ param_list = [{
272
+ 'params': backbone,
273
+ 'initial_lr': args.lr_multi * float(args.base_lr)
274
+ }, {
275
+ 'params': head,
276
+ 'initial_lr': args.base_lr
277
+ }]
278
+ return model, param_list
279
+
280
+ def build_hash(args):
281
+ model = Clip_hash_model(args)
282
+ backbone = []
283
+ head = []
284
+ for k, v in model.named_parameters():
285
+ if k.startswith('backbone') and 'positional_embedding' not in k:
286
+ backbone.append(v)
287
+ else:
288
+ head.append(v)
289
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
290
+ param_list = [{
291
+ 'params': backbone,
292
+ 'initial_lr': args.lr_multi * args.base_lr
293
+ }, {
294
+ 'params': head,
295
+ 'initial_lr': args.base_lr
296
+ }]
297
+ return model, param_list
298
+
299
+ def build_zh_segmenter(args):
300
+ model = zh_clip(args)
301
+ backbone = []
302
+ head = []
303
+ for k, v in model.named_parameters():
304
+ if k.startswith('backbone') and 'positional_embedding' not in k:
305
+ backbone.append(v)
306
+ else:
307
+ head.append(v)
308
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
309
+ param_list = [{
310
+ 'params': backbone,
311
+ 'initial_lr': args.lr_multi * args.base_lr
312
+ }, {
313
+ 'params': head,
314
+ 'initial_lr': args.base_lr
315
+ }]
316
+ return model, param_list
317
+
318
+ def build_poi_segmenter(args):
319
+ model = poi_clip(args)
320
+ backbone = []
321
+ head = []
322
+ for k, v in model.named_parameters():
323
+ if k.startswith('backbone') and 'positional_embedding' not in k:
324
+ backbone.append(v)
325
+ else:
326
+ head.append(v)
327
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
328
+ param_list = [{
329
+ 'params': backbone,
330
+ 'initial_lr': args.lr_multi * args.base_lr
331
+ }, {
332
+ 'params': head,
333
+ 'initial_lr': args.base_lr
334
+ }]
335
+ return model, param_list
336
+
337
+ def build_clip(args):
338
+ model = Clip_model(args)
339
+ backbone = []
340
+ head = []
341
+ for k, v in model.named_parameters():
342
+ if k.startswith('backbone') and 'positional_embedding' not in k:
343
+ backbone.append(v)
344
+ else:
345
+ head.append(v)
346
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
347
+ param_list = [{
348
+ 'params': backbone,
349
+ 'initial_lr': args.lr_multi * args.base_lr
350
+ }, {
351
+ 'params': head,
352
+ 'initial_lr': args.base_lr
353
+ }]
354
+ return model, param_list
cisen/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (695 Bytes). View file
 
cisen/model/__pycache__/clip.cpython-38.pyc ADDED
Binary file (16.7 kB). View file
 
cisen/model/__pycache__/layers.cpython-38.pyc ADDED
Binary file (9.07 kB). View file
 
cisen/model/__pycache__/segmenter.cpython-38.pyc ADDED
Binary file (1.66 kB). View file
 
cisen/model/builder.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2022, Huawei Technologies Co., Ltd. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from mmcv import Registry
18
+ from mmcv import build_from_cfg
19
+
20
+ MODELS = Registry('model')
21
+
22
+
23
+ def build_model(config):
24
+
25
+ return build_from_cfg(config, MODELS)
cisen/model/clip.py ADDED
@@ -0,0 +1,1207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from ..utils.dataset import tokenize
9
+ from ..utils.simple_tokenizer import SimpleTokenizer as _Tokenizer
10
+ _tokenizer = _Tokenizer()
11
+
12
+
13
+ class Bottleneck(nn.Module):
14
+ expansion = 4
15
+
16
+ def __init__(self, inplanes, planes, stride=1):
17
+ super().__init__()
18
+
19
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
20
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
21
+ self.bn1 = nn.BatchNorm2d(planes)
22
+
23
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+
26
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
27
+
28
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
29
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
30
+
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.downsample = None
33
+ self.stride = stride
34
+
35
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
36
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
37
+ self.downsample = nn.Sequential(
38
+ OrderedDict([("-1", nn.AvgPool2d(stride)),
39
+ ("0",
40
+ nn.Conv2d(inplanes,
41
+ planes * self.expansion,
42
+ 1,
43
+ stride=1,
44
+ bias=False)),
45
+ ("1", nn.BatchNorm2d(planes * self.expansion))]))
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ identity = x
49
+
50
+ out = self.relu(self.bn1(self.conv1(x)))
51
+ out = self.relu(self.bn2(self.conv2(out)))
52
+ out = self.avgpool(out)
53
+ out = self.bn3(self.conv3(out))
54
+
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+
58
+ out += identity
59
+ out = self.relu(out)
60
+ return out
61
+
62
+
63
+ """
64
+ attenpool used in CRIS (output: C1/C2/C3 3 deiffent feature maps)
65
+ """
66
+ class ModifiedAttentionPool2d(nn.Module):
67
+ def __init__(self,
68
+ spacial_dim: int,
69
+ embed_dim: int,
70
+ num_heads: int,
71
+ output_dim: int = None):
72
+ super().__init__()
73
+ self.spacial_dim = spacial_dim
74
+ self.positional_embedding = nn.Parameter(
75
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
76
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
77
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
78
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
79
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
80
+ self.num_heads = num_heads
81
+ # residual
82
+ self.connect = nn.Sequential(
83
+ nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
84
+ nn.BatchNorm2d(output_dim))
85
+
86
+ def resize_pos_embed(self, pos_embed, input_shpae):
87
+ """Resize pos_embed weights.
88
+ Resize pos_embed using bicubic interpolate method.
89
+ Args:
90
+ pos_embed (torch.Tensor): Position embedding weights.
91
+ input_shpae (tuple): Tuple for (downsampled input image height,
92
+ downsampled input image width).
93
+ pos_shape (tuple): The resolution of downsampled origin training
94
+ image.
95
+ mode (str): Algorithm used for upsampling:
96
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
97
+ ``'trilinear'``. Default: ``'nearest'``
98
+ Return:
99
+ torch.Tensor: The resized pos_embed of shape [B, C, L_new]
100
+ """
101
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
102
+ pos_h = pos_w = self.spacial_dim
103
+ cls_token_weight = pos_embed[:, 0]
104
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
105
+ pos_embed_weight = pos_embed_weight.reshape(
106
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
107
+ pos_embed_weight = F.interpolate(pos_embed_weight,
108
+ size=input_shpae,
109
+ align_corners=False,
110
+ mode='bicubic')
111
+ cls_token_weight = cls_token_weight.unsqueeze(1)
112
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
113
+ # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
114
+ return pos_embed_weight.transpose(-2, -1)
115
+
116
+ def forward(self, x):
117
+ B, C, H, W = x.size()
118
+ res = self.connect(x)
119
+ x = x.reshape(B, C, -1) # NC(HW)
120
+ # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(1+HW)
121
+ pos_embed = self.positional_embedding.unsqueeze(0)
122
+ pos_embed = self.resize_pos_embed(pos_embed, (H, W)) # NC(HW)
123
+ x = x + pos_embed.to(x.dtype) # NC(HW)
124
+ x = x.permute(2, 0, 1) # (HW)NC
125
+ x, _ = F.multi_head_attention_forward(
126
+ query=x,
127
+ key=x,
128
+ value=x,
129
+ embed_dim_to_check=x.shape[-1],
130
+ num_heads=self.num_heads,
131
+ q_proj_weight=self.q_proj.weight,
132
+ k_proj_weight=self.k_proj.weight,
133
+ v_proj_weight=self.v_proj.weight,
134
+ in_proj_weight=None,
135
+ in_proj_bias=torch.cat(
136
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
137
+ bias_k=None,
138
+ bias_v=None,
139
+ add_zero_attn=False,
140
+ dropout_p=0,
141
+ out_proj_weight=self.c_proj.weight,
142
+ out_proj_bias=self.c_proj.bias,
143
+ use_separate_proj_weight=True,
144
+ training=self.training,
145
+ need_weights=False)
146
+ xt = x[0]
147
+ x = x.permute(1, 2, 0).reshape(B, -1, H, W)
148
+ x = x + res
149
+ x = F.relu(x, True)
150
+
151
+ return x, xt
152
+
153
+
154
+ """
155
+ attenpool used in Clip (output: a tensor (b, dim) image encoding)
156
+ """
157
+ class AttentionPool2d(nn.Module):
158
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
159
+ super().__init__()
160
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
161
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
162
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
163
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
164
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
165
+ self.num_heads = num_heads
166
+
167
+ def forward(self, x):
168
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
169
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
170
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
171
+ x, _ = F.multi_head_attention_forward(
172
+ query=x[:1], key=x, value=x,
173
+ embed_dim_to_check=x.shape[-1],
174
+ num_heads=self.num_heads,
175
+ q_proj_weight=self.q_proj.weight,
176
+ k_proj_weight=self.k_proj.weight,
177
+ v_proj_weight=self.v_proj.weight,
178
+ in_proj_weight=None,
179
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
180
+ bias_k=None,
181
+ bias_v=None,
182
+ add_zero_attn=False,
183
+ dropout_p=0,
184
+ out_proj_weight=self.c_proj.weight,
185
+ out_proj_bias=self.c_proj.bias,
186
+ use_separate_proj_weight=True,
187
+ training=self.training,
188
+ need_weights=False
189
+ )
190
+ return x.squeeze(0)
191
+
192
+
193
+ class ModifiedResNet(nn.Module):
194
+ """
195
+ A ResNet class that is similar to torchvision's but contains the following changes:
196
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
197
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
198
+ - The final pooling layer is a QKV attention instead of an average pool
199
+ """
200
+ def __init__(self,
201
+ layers,
202
+ output_dim,
203
+ heads,
204
+ input_resolution=224,
205
+ width=64):
206
+ super().__init__()
207
+ self.output_dim = output_dim
208
+ self.input_resolution = input_resolution
209
+
210
+ # the 3-layer stem
211
+ self.conv1 = nn.Conv2d(3,
212
+ width // 2,
213
+ kernel_size=3,
214
+ stride=2,
215
+ padding=1,
216
+ bias=False)
217
+ self.bn1 = nn.BatchNorm2d(width // 2)
218
+ self.conv2 = nn.Conv2d(width // 2,
219
+ width // 2,
220
+ kernel_size=3,
221
+ padding=1,
222
+ bias=False)
223
+ self.bn2 = nn.BatchNorm2d(width // 2)
224
+ self.conv3 = nn.Conv2d(width // 2,
225
+ width,
226
+ kernel_size=3,
227
+ padding=1,
228
+ bias=False)
229
+ self.bn3 = nn.BatchNorm2d(width)
230
+ self.avgpool = nn.AvgPool2d(2)
231
+ self.relu = nn.ReLU(inplace=True)
232
+
233
+ # residual layers
234
+ self._inplanes = width # this is a *mutable* variable used during construction
235
+ self.layer1 = self._make_layer(width, layers[0])
236
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
237
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
238
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
239
+
240
+ embed_dim = width * 32 # the ResNet feature dimension
241
+
242
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
243
+ heads, output_dim)
244
+ # self.modifiedattnpool = ModifiedAttentionPool2d(input_resolution // 32, embed_dim,
245
+ # heads, output_dim)
246
+
247
+ def _make_layer(self, planes, blocks, stride=1):
248
+ layers = [Bottleneck(self._inplanes, planes, stride)]
249
+
250
+ self._inplanes = planes * Bottleneck.expansion
251
+ for _ in range(1, blocks):
252
+ layers.append(Bottleneck(self._inplanes, planes))
253
+
254
+ return nn.Sequential(*layers)
255
+
256
+ def forward(self, x):
257
+ def stem(x):
258
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
259
+ (self.conv3, self.bn3)]:
260
+
261
+ x = self.relu(bn(conv(x)))
262
+
263
+ x = self.avgpool(x)
264
+ return x
265
+
266
+ x = x.type(self.conv1.weight.dtype)
267
+ x = stem(x)
268
+
269
+ x = self.layer1(x)
270
+
271
+ x2 = self.layer2(x)
272
+
273
+ x3 = self.layer3(x2)
274
+ x4 = self.layer4(x3)
275
+ x5 = self.attnpool(x4)
276
+ # x4 = self.modifiedattnpool(x4)
277
+
278
+ return (x2, x3, x4), x5
279
+
280
+
281
+ class LayerNorm(nn.LayerNorm):
282
+ """Subclass torch's LayerNorm to handle fp16."""
283
+ def forward(self, x: torch.Tensor):
284
+ orig_type = x.dtype
285
+ ret = super().forward(x.type(torch.float32))
286
+ return ret.type(orig_type)
287
+
288
+
289
+ class QuickGELU(nn.Module):
290
+ def forward(self, x: torch.Tensor):
291
+ return x * torch.sigmoid(1.702 * x)
292
+
293
+
294
+ class ResidualAttentionBlock(nn.Module):
295
+ def __init__(self,
296
+ d_model: int,
297
+ n_head: int,
298
+ attn_mask: torch.Tensor = None):
299
+ super().__init__()
300
+ # print(n_head)
301
+ self.attn = nn.MultiheadAttention(d_model, n_head)
302
+ self.ln_1 = LayerNorm(d_model)
303
+ self.mlp = nn.Sequential(
304
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
305
+ ("gelu", QuickGELU()),
306
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
307
+ self.ln_2 = LayerNorm(d_model)
308
+ self.attn_mask = attn_mask
309
+
310
+ def attention(self, x: torch.Tensor):
311
+ self.attn_mask = self.attn_mask.to(
312
+ dtype=x.dtype,
313
+ device=x.device) if self.attn_mask is not None else None
314
+ res = self.attn(x, x, x, need_weights=False,
315
+ attn_mask=self.attn_mask)[0]
316
+ # print(res)
317
+ return res
318
+
319
+ def forward(self, x: torch.Tensor):
320
+ # a = self.attention(self.ln_1(x))
321
+ x = x + self.attention(self.ln_1(x))
322
+
323
+ x = x + self.mlp(self.ln_2(x))
324
+ return x
325
+
326
+ class Transformer(nn.Module):
327
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
328
+ super().__init__()
329
+ self.width = width
330
+ self.layers = layers
331
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
332
+
333
+ def forward(self, x: torch.Tensor):
334
+ return self.resblocks(x)
335
+
336
+ class ViTTransformer(nn.Module):
337
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
338
+ super().__init__()
339
+ self.width = width
340
+ self.layers = layers
341
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
342
+
343
+ def forward(self, x: torch.Tensor):
344
+ outputs = []
345
+ i = 1
346
+ for block in self.resblocks:
347
+ x = block(x)
348
+ if i > 7:
349
+ outputs.append(x)
350
+ i = i + 1
351
+ return outputs
352
+
353
+
354
+ class VisionTransformer(nn.Module):
355
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
356
+ layers: int, heads: int, output_dim: int):
357
+ super().__init__()
358
+ self.input_resolution = input_resolution
359
+ self.output_dim = output_dim
360
+ self.conv1 = nn.Conv2d(in_channels=3,
361
+ out_channels=width,
362
+ kernel_size=patch_size,
363
+ stride=patch_size,
364
+ bias=False)
365
+
366
+ scale = width ** -0.5
367
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
368
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
369
+ (input_resolution // patch_size) ** 2 + 1, width))
370
+ self.ln_pre = LayerNorm(width)
371
+
372
+ self.transformer = ViTTransformer(width, layers, heads)
373
+
374
+ self.ln_post = LayerNorm(width)
375
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
376
+
377
+ def forward(self, x: torch.Tensor):
378
+ # input: batch, 3, 224, 224
379
+
380
+ # batch, 1024, 16, 16
381
+ x = self.conv1(x) # shape = [*, width, grid, grid]
382
+ # batch, 1024, 256
383
+ x = x.reshape(x.shape[0], x.shape[1],
384
+ -1) # shape = [*, width, grid ** 2]
385
+ # batch, 256, 1024
386
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
387
+ # batch, 257, 1024
388
+ x = torch.cat([
389
+ self.class_embedding.to(x.dtype) + torch.zeros(
390
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
391
+ ],
392
+ dim=1) # shape = [*, grid ** 2 + 1, width]
393
+
394
+ x = x + self.positional_embedding.to(x.dtype)
395
+
396
+ x = self.ln_pre(x)
397
+ # 257, batch, 1024
398
+ x = x.permute(1, 0, 2) # NLD -> LND
399
+
400
+ out = self.transformer(x)
401
+ # batch, 257, 1024
402
+ x1, x2 ,x3, x4 = out[0], out[1], out[2], out[3]
403
+ x1 = x1.permute(1, 0, 2)
404
+ x2 = x2.permute(1, 0, 2)
405
+ x3 = x3.permute(1, 0, 2)
406
+ x4 = x4.permute(1, 0, 2) # LND -> NLD
407
+
408
+ # 用于分类
409
+ x = self.ln_post(x4[:, 0, :])
410
+ #feature
411
+ # x_f = self.ln_post(x[:, 1:, :])
412
+
413
+
414
+ if self.proj is not None:
415
+ x = x @ self.proj
416
+
417
+ return (x1[:, 1:, :], x2[:, 1:, :], x3[:, 1:, :], x4[:, 1:, :]), x
418
+
419
+ class ModifiedVisionTransformer(nn.Module):
420
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
421
+ layers: int, heads: int, output_dim: int):
422
+ super().__init__()
423
+ self.input_resolution = input_resolution
424
+ self.output_dim = output_dim
425
+ self.conv1 = nn.Conv2d(in_channels=3,
426
+ out_channels=width,
427
+ kernel_size=patch_size,
428
+ stride=patch_size,
429
+ bias=False)
430
+
431
+ self.conv2 = nn.Conv2d(in_channels=3,
432
+ out_channels=width // 2,
433
+ kernel_size=patch_size // 2,
434
+ stride=patch_size // 2,
435
+ bias=False)
436
+
437
+ self.conv3 = nn.Conv2d(in_channels=3,
438
+ out_channels=width,
439
+ kernel_size=patch_size * 2,
440
+ stride=patch_size * 2,
441
+ bias=False)
442
+ self.conv_layers = [self.conv1, self.conv2]
443
+ scale = width**-0.5
444
+
445
+ self.class_embedding1 = nn.Parameter(scale * torch.randn(width))
446
+ self.class_embedding2 = nn.Parameter(scale * torch.randn(width // 2))
447
+ self.cls_layers = [self.class_embedding1, self.class_embedding2]
448
+
449
+ self.positional_embedding1 = nn.Parameter(scale * torch.randn(
450
+ (input_resolution // patch_size)**2 + 1, width))
451
+ self.positional_embedding2 = nn.Parameter(scale * torch.randn(
452
+ (input_resolution // (patch_size // 2)) ** 2 + 1, width // 2))
453
+ self.pos_layers = [self.positional_embedding1, self.positional_embedding2]
454
+
455
+ self.ln_pre1 = LayerNorm(width)
456
+ self.ln_pre2 = LayerNorm(width // 2)
457
+ self.pre_layers = [self.ln_pre1, self.ln_pre2]
458
+
459
+ self.transformer1 = Transformer(width, layers, heads)
460
+ self.transformer2 = Transformer(width // 2, layers, heads)
461
+ self.tran_layers = [self.transformer1, self.transformer2]
462
+
463
+ self.ln_post1 = LayerNorm(width)
464
+ self.ln_post2 = LayerNorm(width // 2)
465
+ self.post_layers = [self.ln_post1, self.ln_post2]
466
+
467
+ self.proj1 = nn.Parameter(scale * torch.randn(width, output_dim * 2))
468
+ self.proj2 = nn.Parameter(scale * torch.randn(width // 2, output_dim))
469
+ self.proj_layers = [self.proj1, self.proj2]
470
+
471
+
472
+ def forward(self, x: torch.Tensor):
473
+ # input: batch, 3, 224, 224
474
+ input = x
475
+ # batch, 1024, 16, 16
476
+ out = []
477
+ f = []
478
+ cl = []
479
+ for i in range(2):
480
+ x = self.conv_layers[i](input) # shape = [*, width, grid, grid]
481
+
482
+ b, c, w, h = x.shape
483
+ # batch, 1024, 256
484
+ x = x.reshape(x.shape[0], x.shape[1],
485
+ -1) # shape = [*, width, grid ** 2]
486
+ # batch, 256, 1024
487
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
488
+ # batch, 257, 1024
489
+ x = torch.cat([
490
+ self.cls_layers[i].to(x.dtype) + torch.zeros(
491
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
492
+ ],
493
+ dim=1) # shape = [*, grid ** 2 + 1, width]
494
+
495
+ x = x + self.pos_layers[i].to(x.dtype)
496
+
497
+ x = self.pre_layers[i](x)
498
+ # 257, batch, 1024
499
+ x = x.permute(1, 0, 2) # NLD -> LND
500
+
501
+ x, cls = self.tran_layers[i](x)
502
+ # batch, 257, 1024
503
+ x = x.permute(1, 0, 2) # LND -> NLD
504
+
505
+ # 用于分类
506
+ # x = self.ln_post(x[:, 0, :])
507
+ # feature
508
+ x = self.post_layers[i](x[:, 1:, :])
509
+
510
+
511
+
512
+ if self.proj_layers[i] is not None:
513
+ x = x @ self.proj_layers[i]
514
+ cls = [j @ self.proj_layers[i] for j in cls]
515
+
516
+ feat = x.permute(0,2,1).reshape(b, x.shape[2] , w, h)
517
+ out.append(x)
518
+ f.append(feat)
519
+ cl.append(cls)
520
+ return out, f, cl
521
+
522
+ """
523
+ Long CLIP
524
+ """
525
+ class LCLIP(nn.Module):
526
+ def __init__(self,
527
+ embed_dim: int,
528
+ # vision
529
+ image_resolution: int,
530
+ vision_layers: Union[Tuple[int, int, int, int], int],
531
+ vision_width: int,
532
+ vision_patch_size: int,
533
+ # text
534
+ context_length: int,
535
+ vocab_size: int,
536
+ transformer_width: int,
537
+ transformer_heads: int,
538
+ transformer_layers: int,
539
+ load_from_clip: bool
540
+ ):
541
+ super().__init__()
542
+ self.context_length = 248
543
+
544
+ if isinstance(vision_layers, (tuple, list)):
545
+ vision_heads = vision_width * 32 // 64
546
+ self.visual = ModifiedResNet(
547
+ layers=vision_layers,
548
+ output_dim=embed_dim,
549
+ heads=vision_heads,
550
+ input_resolution=image_resolution,
551
+ width=vision_width
552
+ )
553
+ else:
554
+ vision_heads = vision_width // 64
555
+ self.visual = VisionTransformer(
556
+ input_resolution=image_resolution,
557
+ patch_size=vision_patch_size,
558
+ width=vision_width,
559
+ layers=vision_layers,
560
+ heads=vision_heads,
561
+ output_dim=embed_dim
562
+ )
563
+
564
+ self.transformer = Transformer(
565
+ width=transformer_width,
566
+ layers=transformer_layers,
567
+ heads=transformer_heads,
568
+ attn_mask=self.build_attention_mask()
569
+ )
570
+
571
+ self.vocab_size = vocab_size
572
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
573
+ # self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
574
+
575
+ if load_from_clip == False:
576
+ self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
577
+ self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))
578
+
579
+ else:
580
+ self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
581
+
582
+ self.ln_final = LayerNorm(transformer_width)
583
+
584
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
585
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
586
+
587
+ self.initialize_parameters()
588
+ self.mask1 = torch.zeros([248, 1])
589
+ self.mask1[:20, :] = 1
590
+ self.mask2 = torch.zeros([248, 1])
591
+ self.mask2[20:, :] = 1
592
+
593
+
594
+ def initialize_parameters(self):
595
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
596
+ nn.init.normal_(self.positional_embedding, std=0.01)
597
+
598
+ if isinstance(self.visual, ModifiedResNet):
599
+ if self.visual.attnpool is not None:
600
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
601
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
602
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
603
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
604
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
605
+
606
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
607
+ for name, param in resnet_block.named_parameters():
608
+ if name.endswith("bn3.weight"):
609
+ nn.init.zeros_(param)
610
+
611
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
612
+ attn_std = self.transformer.width ** -0.5
613
+ fc_std = (2 * self.transformer.width) ** -0.5
614
+ for block in self.transformer.resblocks:
615
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
616
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
617
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
618
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
619
+
620
+ if self.text_projection is not None:
621
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
622
+
623
+ def build_attention_mask(self):
624
+ # lazily create causal attention mask, with full attention between the vision tokens
625
+ # pytorch uses additive attention mask; fill with -inf
626
+ mask = torch.empty(self.context_length, self.context_length)
627
+ mask.fill_(float("-inf"))
628
+ mask.triu_(1) # zero out the lower diagonal
629
+ return mask
630
+
631
+ @property
632
+ def dtype(self):
633
+ return self.visual.conv1.weight.dtype
634
+
635
+ def encode_image(self, image):
636
+ return self.visual(image.type(self.dtype))
637
+
638
+ def encode_text(self, text):
639
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
640
+
641
+ # x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)
642
+ x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device)
643
+ x = x.permute(1, 0, 2) # NLD -> LND
644
+ x = self.transformer(x)
645
+ x = x.permute(1, 0, 2) # LND -> NLD
646
+ x = self.ln_final(x).type(self.dtype)
647
+
648
+ # x.shape = [batch_size, n_ctx, transformer.width]
649
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
650
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
651
+
652
+ return x
653
+
654
+ def encode_text_full(self, text):
655
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
656
+
657
+ x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)
658
+
659
+ x = x.permute(1, 0, 2) # NLD -> LND
660
+ x = self.transformer(x)
661
+ x = x.permute(1, 0, 2) # LND -> NLD
662
+ x = self.ln_final(x).type(self.dtype)
663
+
664
+ # x.shape = [batch_size, n_ctx, transformer.width]
665
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
666
+ #x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
667
+
668
+ return x
669
+
670
+
671
+ def forward(self, image, text):
672
+ image_features = self.encode_image(image)
673
+ text_features, _ = self.encode_text(text)
674
+
675
+ # normalized features
676
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
677
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
678
+
679
+ # cosine similarity as logits
680
+ logit_scale = self.logit_scale.exp()
681
+ logits_per_image = logit_scale * image_features @ text_features.t()
682
+ logits_per_text = logits_per_image.t()
683
+
684
+ # shape = [global_batch_size, global_batch_size]
685
+ return logits_per_image, logits_per_text
686
+ """
687
+ original CLIP
688
+ """
689
+ class CLIP(nn.Module):
690
+ def __init__(
691
+ self,
692
+ embed_dim: int,
693
+ # vision
694
+ image_resolution: int,
695
+ vision_layers: Union[Tuple[int, int, int, int], int],
696
+ vision_width: int,
697
+ vision_patch_size: int,
698
+ # text
699
+ context_length: int,
700
+ txt_length: int,
701
+ vocab_size: int,
702
+ transformer_width: int,
703
+ transformer_heads: int,
704
+ transformer_layers: int):
705
+ super().__init__()
706
+
707
+ self.context_length = context_length
708
+
709
+ if isinstance(vision_layers, (tuple, list)):
710
+ vision_heads = vision_width * 32 // 64
711
+ self.visual = ModifiedResNet(layers=vision_layers,
712
+ output_dim=embed_dim,
713
+ heads=vision_heads,
714
+ input_resolution=image_resolution,
715
+ width=vision_width)
716
+ # self.fq_attnpool = AttentionPool2d(image_resolution // 32, vision_width* 32,
717
+ # vision_heads, embed_dim)
718
+ else:
719
+ vision_heads = vision_width // 64
720
+ self.visual = VisionTransformer(input_resolution=image_resolution,
721
+ patch_size=vision_patch_size,
722
+ width=vision_width,
723
+ layers=vision_layers,
724
+ heads=vision_heads,
725
+ output_dim=embed_dim)
726
+
727
+ self.transformer = Transformer(
728
+ width=transformer_width,
729
+ layers=transformer_layers,
730
+ heads=transformer_heads,
731
+ attn_mask=self.build_attention_mask(txt_length))
732
+
733
+ self.vocab_size = vocab_size
734
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
735
+ self.positional_embedding = nn.Parameter(
736
+ torch.empty(self.context_length, transformer_width))
737
+ self.ln_final = LayerNorm(transformer_width)
738
+
739
+ self.text_projection = nn.Parameter(
740
+ torch.empty(transformer_width, embed_dim))
741
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
742
+
743
+ self.token_embedding.requires_grad_ = False
744
+ self.initialize_parameters()
745
+
746
+ def initialize_parameters(self):
747
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
748
+ nn.init.normal_(self.positional_embedding, std=0.01)
749
+
750
+ if isinstance(self.visual, ModifiedResNet):
751
+ if self.visual.attnpool is not None:
752
+ std = self.visual.attnpool.c_proj.in_features**-0.5
753
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
754
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
755
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
756
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
757
+
758
+ for resnet_block in [
759
+ self.visual.layer1, self.visual.layer2, self.visual.layer3,
760
+ self.visual.layer4
761
+ ]:
762
+ for name, param in resnet_block.named_parameters():
763
+ if name.endswith("bn3.weight"):
764
+ nn.init.zeros_(param)
765
+
766
+ proj_std = (self.transformer.width**-0.5) * (
767
+ (2 * self.transformer.layers)**-0.5)
768
+ attn_std = self.transformer.width**-0.5
769
+ fc_std = (2 * self.transformer.width)**-0.5
770
+ for block in self.transformer.resblocks:
771
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
772
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
773
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
774
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
775
+
776
+ if self.text_projection is not None:
777
+ nn.init.normal_(self.text_projection,
778
+ std=self.transformer.width**-0.5)
779
+
780
+ def build_attention_mask(self, context_length):
781
+ # lazily create causal attention mask, with full attention between the vision tokens
782
+ # pytorch uses additive attention mask; fill with -inf
783
+ mask = torch.empty(context_length, context_length)
784
+ mask.fill_(float("-inf"))
785
+ mask.triu_(1) # zero out the lower diagonal
786
+ return mask
787
+
788
+ @property
789
+ def dtype(self):
790
+ return self.visual.conv1.weight.dtype
791
+
792
+ def encode_image(self, image):
793
+ return self.visual(image.type(self.dtype))
794
+
795
+ def encode_fq(self, image):
796
+ return self.fq_attnpool(image.type(self.dtype))
797
+
798
+ def encode_text(self, text):
799
+ a = self.token_embedding
800
+ x = self.token_embedding(text).type(
801
+ self.dtype) # [batch_size, n_ctx, d_model]
802
+
803
+ x = x + self.positional_embedding.type(self.dtype)[:x.size(1)]
804
+ # print(x.shape)
805
+ # print(x)
806
+
807
+ x = x.permute(1, 0, 2) # NLD -> LND
808
+ x = self.transformer(x)
809
+ x = x.permute(1, 0, 2) # LND -> NLD
810
+ x = self.ln_final(x).type(self.dtype)
811
+ # print(text[0])
812
+ # x.shape = [batch_size, n_ctx, transformer.width]
813
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
814
+ state = x[torch.arange(x.shape[0]),
815
+ text.argmax(dim=-1)] @ self.text_projection
816
+ # x = x @ self.text_projection
817
+ # state = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
818
+
819
+ return x, state
820
+
821
+ def forward(self, image, text):
822
+ image_features = self.encode_image(image)
823
+ text_features = self.encode_text(text)
824
+
825
+ # normalized features
826
+ image_features = image_features / image_features.norm(dim=-1,
827
+ keepdim=True)
828
+ text_features = text_features / text_features.norm(dim=-1,
829
+ keepdim=True)
830
+
831
+ # cosine similarity as logits
832
+ logit_scale = self.logit_scale.exp()
833
+ logits_per_image = logit_scale * image_features @ text_features.t()
834
+ logits_per_text = logits_per_image.t()
835
+
836
+ # shape = [global_batch_size, global_batch_size]
837
+ return logits_per_image, logits_per_text
838
+
839
+ """
840
+ modified CLIP : without text encoder
841
+ """
842
+
843
+ class zhCLIP(nn.Module):
844
+ def __init__(self,
845
+ embed_dim,
846
+ # vision
847
+ image_resolution: int,
848
+ vision_layers: Union[Tuple[int, int, int, int], int],
849
+ vision_width: int,
850
+ vision_patch_size: int):
851
+ super().__init__()
852
+
853
+
854
+
855
+ if isinstance(vision_layers, (tuple, list)):
856
+ vision_heads = vision_width * 32 // 64
857
+ self.visual = ModifiedResNet(layers=vision_layers,
858
+ output_dim=embed_dim,
859
+ heads=vision_heads,
860
+ input_resolution=image_resolution,
861
+ width=vision_width)
862
+ self.fq_attnpool = AttentionPool2d(image_resolution // 32, vision_width* 32,
863
+ vision_heads, embed_dim)
864
+ else:
865
+ vision_heads = vision_width // 64
866
+ self.visual = ModifiedVisionTransformer(input_resolution=image_resolution,
867
+ patch_size=vision_patch_size,
868
+ width=vision_width,
869
+ layers=vision_layers,
870
+ heads=vision_heads,
871
+ output_dim=embed_dim)
872
+
873
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
874
+ self.initialize_parameters()
875
+
876
+ def initialize_parameters(self):
877
+
878
+ if isinstance(self.visual, ModifiedResNet):
879
+ if self.visual.attnpool is not None:
880
+ std = self.visual.attnpool.c_proj.in_features**-0.5
881
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
882
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
883
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
884
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
885
+
886
+ for resnet_block in [
887
+ self.visual.layer1, self.visual.layer2, self.visual.layer3,
888
+ self.visual.layer4
889
+ ]:
890
+ for name, param in resnet_block.named_parameters():
891
+ if name.endswith("bn3.weight"):
892
+ nn.init.zeros_(param)
893
+
894
+
895
+ def build_attention_mask(self, context_length):
896
+ # lazily create causal attention mask, with full attention between the vision tokens
897
+ # pytorch uses additive attention mask; fill with -inf
898
+ mask = torch.empty(context_length, context_length)
899
+ mask.fill_(float("-inf"))
900
+ mask.triu_(1) # zero out the lower diagonal
901
+ return mask
902
+
903
+ @property
904
+ def dtype(self):
905
+ return self.visual.conv1.weight.dtype
906
+
907
+ def encode_image(self, image):
908
+ return self.visual(image.type(self.dtype))
909
+
910
+ def encode_fq(self, image):
911
+ return self.fq_attnpool(image.type(self.dtype))
912
+
913
+ def forward(self, image, text):
914
+ image_features = self.encode_image(image)
915
+ text_features = self.encode_text(text)
916
+
917
+ # normalized features
918
+ image_features = image_features / image_features.norm(dim=-1,
919
+ keepdim=True)
920
+ text_features = text_features / text_features.norm(dim=-1,
921
+ keepdim=True)
922
+
923
+ # cosine similarity as logits
924
+ logit_scale = self.logit_scale.exp()
925
+ logits_per_image = logit_scale * image_features @ text_features.t()
926
+ logits_per_text = logits_per_image.t()
927
+
928
+ # shape = [global_batch_size, global_batch_size]
929
+ return logits_per_image, logits_per_text
930
+
931
+
932
+ def convert_weights(model: nn.Module):
933
+ """Convert applicable model parameters to fp16"""
934
+ def _convert_weights_to_fp16(l):
935
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
936
+ l.weight.data = l.weight.data.half()
937
+ if l.bias is not None:
938
+ l.bias.data = l.bias.data.half()
939
+
940
+ if isinstance(l, nn.MultiheadAttention):
941
+ for attr in [
942
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
943
+ "in_proj_bias", "bias_k", "bias_v"
944
+ ]:
945
+ tensor = getattr(l, attr)
946
+ if tensor is not None:
947
+ tensor.data = tensor.data.half()
948
+
949
+ for name in ["text_projection", "proj"]:
950
+ if hasattr(l, name):
951
+ attr = getattr(l, name)
952
+ if attr is not None:
953
+ attr.data = attr.data.half()
954
+
955
+ model.apply(_convert_weights_to_fp16)
956
+
957
+ class PromptLearner(nn.Module):
958
+
959
+ def __init__(self, transformer_width, context_length, vocab_size,
960
+ transformer_layers, transformer_heads, bert_embed_dim):
961
+ super().__init__()
962
+
963
+ self.transformer_width = transformer_width
964
+ self.context_length = context_length
965
+ self.vocab_size = vocab_size
966
+ self.token_embedding = nn.Embedding(self.vocab_size, self.transformer_width)
967
+
968
+ self.transformer = Transformer(
969
+ width=transformer_width,
970
+ layers=transformer_layers,
971
+ heads=transformer_heads,
972
+ attn_mask=self.build_attention_mask()
973
+ )
974
+
975
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
976
+ self.ln_final = LayerNorm(transformer_width)
977
+
978
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, bert_embed_dim))
979
+
980
+
981
+ # self.load_from_openai_model(pretrained_model=clip_pretrain)
982
+
983
+ def build_attention_mask(self):
984
+ # lazily create causal attention mask, with full attention between the vision tokens
985
+ # pytorch uses additive attention mask; fill with -inf
986
+ mask = torch.empty(self.context_length, self.context_length)
987
+ mask.fill_(float("-inf"))
988
+ mask.triu_(1) # zero out the lower diagonal
989
+ return mask
990
+
991
+ def init_label_emb(self, labels_path):
992
+
993
+ label = open(labels_path, 'r').readlines()
994
+ # label81 = open(unseen_labels_path, 'r').readlines()
995
+ # label1006 = label925 + label81
996
+ self.name_lens = [len(_tokenizer.encode(name)) for name in label]
997
+ self.label_token = torch.zeros((len(self.name_lens), self.context_length), dtype=torch.long)
998
+ for i, c in enumerate(label):
999
+ self.label_token[i] = tokenize(f"There is a {c.strip()} in the scene")
1000
+ self.label_emb = torch.zeros((len(self.name_lens), max(self.name_lens), self.transformer_width))
1001
+ for i, embed in enumerate(self.token_embedding(self.label_token)):
1002
+ self.label_emb[i][:self.name_lens[i]] = embed[4:4 + self.name_lens[i]].clone().detach()
1003
+
1004
+ # def load_from_openai_model(self, pretrained_model):
1005
+ # state_dict = clip.load(pretrained_model, jit=False)[0].state_dict()
1006
+ # load_dict = {}
1007
+ # for k, v in state_dict.items():
1008
+ # if not k.startswith("visual") and (
1009
+ # k not in ["logit_scale", "input_resolution", "context_length", "vocab_size"]):
1010
+ # load_dict[k] = v
1011
+ # msg = self.load_state_dict(load_dict)
1012
+
1013
+ def load_label_emb(self, label=None):
1014
+ self.name_lens = [len(_tokenizer.encode(name.split("\t")[-1])) for name in label]
1015
+ self.label_token = torch.zeros((len(self.name_lens), self.context_length), dtype=torch.long).cuda()
1016
+ for i, c in enumerate(label):
1017
+ name = c.split("\t")[-1]
1018
+ self.label_token[i] = tokenize(f"There is a {name.strip()} in the scene")
1019
+ self.label_emb = torch.zeros((len(self.name_lens), max(self.name_lens), self.transformer_width)).cuda()
1020
+ for i, embed in enumerate(self.token_embedding(self.label_token)):
1021
+ self.label_emb[i][:self.name_lens[i]] = embed[4:4 + self.name_lens[i]].clone().detach()
1022
+
1023
+ def forward(self, device):
1024
+
1025
+ label_embeds = self.token_embedding(self.label_token.to(device))
1026
+
1027
+ for i in range(label_embeds.shape[0]):
1028
+ label_embeds[i, 4:4 + self.name_lens[i], :] = self.label_emb[i][:self.name_lens[i]]
1029
+
1030
+ x = label_embeds + self.positional_embedding
1031
+ x = x.permute(1, 0, 2) # NLD -> LND
1032
+
1033
+ x = self.transformer(x)
1034
+ x = x.permute(1, 0, 2) # LND -> NLD
1035
+ x = self.ln_final(x)
1036
+
1037
+ res = x[torch.arange(x.shape[0]), self.label_token.argmax(dim=-1)] @ self.text_projection
1038
+
1039
+ return res
1040
+
1041
+ def build_promptlearner(state_dict: dict):
1042
+ embed_dim = state_dict["text_projection"].shape[1]
1043
+ context_length = state_dict["positional_embedding"].shape[0]
1044
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
1045
+ transformer_width = state_dict["ln_final.weight"].shape[0]
1046
+ transformer_heads = transformer_width // 64
1047
+ transformer_layers = len(
1048
+ set(
1049
+ k.split(".")[2] for k in state_dict
1050
+ if k.startswith(f"transformer.resblocks")))
1051
+ model = PromptLearner(transformer_width, context_length, vocab_size,
1052
+ transformer_layers, transformer_heads, embed_dim)
1053
+ # model = PromptLearner(embed_dim, vision_patch_size, context_length, txt_length, vocab_size,
1054
+ # transformer_width, transformer_heads, transformer_layers)
1055
+ load_dict = {}
1056
+ for k, v in state_dict.items():
1057
+ if not k.startswith("visual") and (
1058
+ k not in ["logit_scale", "input_resolution", "context_length", "vocab_size"]):
1059
+ load_dict[k] = v
1060
+
1061
+ convert_weights(model)
1062
+ model.load_state_dict(load_dict, False)
1063
+
1064
+ return model
1065
+
1066
+ def build_model(state_dict: dict, txt_length: int):
1067
+ vit = "visual.proj" in state_dict
1068
+
1069
+ if vit:
1070
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
1071
+ vision_layers = len([
1072
+ k for k in state_dict.keys()
1073
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
1074
+ ])
1075
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
1076
+ grid_size = round(
1077
+ (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
1078
+ image_resolution = vision_patch_size * grid_size
1079
+ else:
1080
+ counts: list = [
1081
+ len(
1082
+ set(
1083
+ k.split(".")[2] for k in state_dict
1084
+ if k.startswith(f"visual.layer{b}")))
1085
+ for b in [1, 2, 3, 4]
1086
+ ]
1087
+ vision_layers = tuple(counts)
1088
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
1089
+ output_width = round(
1090
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] -
1091
+ 1)**0.5)
1092
+ vision_patch_size = None
1093
+ assert output_width**2 + 1 == state_dict[
1094
+ "visual.attnpool.positional_embedding"].shape[0]
1095
+ image_resolution = output_width * 32
1096
+
1097
+ vision_heads = vision_width * 32 // 64
1098
+ embed_dim = state_dict["text_projection"].shape[1]
1099
+ # context_length = state_dict["positional_embedding"].shape[0]
1100
+ context_length = txt_length
1101
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
1102
+ transformer_width = state_dict["ln_final.weight"].shape[0]
1103
+ transformer_heads = transformer_width // 64
1104
+ transformer_layers = len(
1105
+ set(
1106
+ k.split(".")[2] for k in state_dict
1107
+ if k.startswith(f"transformer.resblocks")))
1108
+
1109
+ model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
1110
+ vision_patch_size, context_length, txt_length, vocab_size,
1111
+ transformer_width, transformer_heads, transformer_layers)
1112
+
1113
+ for key in ["input_resolution", "context_length", "vocab_size", 'positional_embedding']:
1114
+ if key in state_dict:
1115
+ del state_dict[key]
1116
+
1117
+ convert_weights(model)
1118
+ model.load_state_dict(state_dict, False)
1119
+ return model.eval(), image_resolution, vision_heads, embed_dim, vision_width, vision_patch_size
1120
+
1121
+ def build_lclip_model(state_dict: dict, load_from_clip: bool):
1122
+ vit = "visual.proj" in state_dict
1123
+
1124
+ if vit:
1125
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
1126
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
1127
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
1128
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
1129
+ image_resolution = vision_patch_size * grid_size
1130
+
1131
+ else:
1132
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
1133
+ vision_layers = tuple(counts)
1134
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
1135
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
1136
+ vision_patch_size = None
1137
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
1138
+ image_resolution = output_width * 32
1139
+
1140
+ embed_dim = state_dict["text_projection"].shape[1]
1141
+ # print(embed_dim)
1142
+ context_length = state_dict["positional_embedding"].shape[0]
1143
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
1144
+ transformer_width = state_dict["ln_final.weight"].shape[0]
1145
+ transformer_heads = transformer_width // 64
1146
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
1147
+
1148
+ model = LCLIP(
1149
+ embed_dim,
1150
+ image_resolution, vision_layers, vision_width, vision_patch_size,
1151
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip
1152
+ )
1153
+
1154
+ for key in ["input_resolution", "context_length", "vocab_size"]:
1155
+ if key in state_dict:
1156
+ del state_dict[key]
1157
+
1158
+ convert_weights(model)
1159
+ # model.load_state_dict(state_dict)
1160
+ model.load_state_dict(state_dict, strict=False)
1161
+ vision_heads = vision_width // 64
1162
+ # print(vision_heads)
1163
+ return model.eval(), image_resolution, vision_heads, embed_dim, vision_width, vision_patch_size
1164
+
1165
+ def build_modified_model(state_dict: dict, txt_length: int):
1166
+ vit = "visual.proj" in state_dict
1167
+
1168
+ if vit:
1169
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
1170
+ vision_layers = len([
1171
+ k for k in state_dict.keys()
1172
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
1173
+ ])
1174
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
1175
+ grid_size = round(
1176
+ (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
1177
+ image_resolution = vision_patch_size * grid_size
1178
+ else:
1179
+ counts: list = [
1180
+ len(
1181
+ set(
1182
+ k.split(".")[2] for k in state_dict
1183
+ if k.startswith(f"visual.layer{b}")))
1184
+ for b in [1, 2, 3, 4]
1185
+ ]
1186
+ vision_layers = tuple(counts)
1187
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
1188
+
1189
+ output_width = round(
1190
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] -
1191
+ 1)**0.5)
1192
+ vision_patch_size = None
1193
+ assert output_width**2 + 1 == state_dict[
1194
+ "visual.attnpool.positional_embedding"].shape[0]
1195
+ image_resolution = output_width * 32
1196
+ embed_dim = state_dict["text_projection"].shape[1]
1197
+
1198
+ model = zhCLIP(embed_dim, image_resolution, vision_layers, vision_width,
1199
+ vision_patch_size)
1200
+
1201
+ for key in ["input_resolution", "context_length", "vocab_size"]:
1202
+ if key in state_dict:
1203
+ del state_dict[key]
1204
+
1205
+ convert_weights(model)
1206
+ model.load_state_dict(state_dict, False)
1207
+ return model.eval()
cisen/model/layers.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ # import open_clip
7
+
8
+ def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
9
+ return nn.Sequential(
10
+ nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
11
+ nn.BatchNorm2d(out_dim), nn.ReLU(True))
12
+ # return nn.Sequential(
13
+ # nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
14
+ # nn.LayerNorm(out_dim), nn.ReLU(True))
15
+
16
+
17
+ # def conv_layer_1(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
18
+ # return nn.Sequential(
19
+ # nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
20
+ # nn.LayerNorm(out_dim), nn.ReLU(True))
21
+
22
+ def linear_layer(in_dim, out_dim,bias=False):
23
+ return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
24
+ nn.BatchNorm1d(out_dim), nn.ReLU(True))
25
+ # return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
26
+ # nn.LayerNorm(out_dim), nn.ReLU(True))
27
+ class AttentionPool2d(nn.Module):
28
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
29
+ super().__init__()
30
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
31
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
32
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
33
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
34
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
35
+ self.num_heads = num_heads
36
+
37
+ def forward(self, x):
38
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
39
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
40
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
41
+ x, _ = F.multi_head_attention_forward(
42
+ query=x[:1], key=x, value=x,
43
+ embed_dim_to_check=x.shape[-1],
44
+ num_heads=self.num_heads,
45
+ q_proj_weight=self.q_proj.weight,
46
+ k_proj_weight=self.k_proj.weight,
47
+ v_proj_weight=self.v_proj.weight,
48
+ in_proj_weight=None,
49
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
50
+ bias_k=None,
51
+ bias_v=None,
52
+ add_zero_attn=False,
53
+ dropout_p=0,
54
+ out_proj_weight=self.c_proj.weight,
55
+ out_proj_bias=self.c_proj.bias,
56
+ use_separate_proj_weight=True,
57
+ training=self.training,
58
+ need_weights=False
59
+ )
60
+ return x.squeeze(0)
61
+
62
+ # class AttentionPool2d(nn.Module):
63
+ # def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
64
+ # super().__init__()
65
+ # self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
66
+ # self.k_proj = nn.Linear(embed_dim, embed_dim)
67
+ # self.q_proj = nn.Linear(embed_dim, embed_dim)
68
+ # self.v_proj = nn.Linear(embed_dim, embed_dim)
69
+ # self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
70
+ # self.num_heads = num_heads
71
+ #
72
+ # def forward(self, x):
73
+ # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
74
+ # x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
75
+ # x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
76
+ # x, _ = F.multi_head_attention_forward(
77
+ # query=x, key=x, value=x,
78
+ # embed_dim_to_check=x.shape[-1],
79
+ # num_heads=self.num_heads,
80
+ # q_proj_weight=self.q_proj.weight,
81
+ # k_proj_weight=self.k_proj.weight,
82
+ # v_proj_weight=self.v_proj.weight,
83
+ # in_proj_weight=None,
84
+ # in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
85
+ # bias_k=None,
86
+ # bias_v=None,
87
+ # add_zero_attn=False,
88
+ # dropout_p=0,
89
+ # out_proj_weight=self.c_proj.weight,
90
+ # out_proj_bias=self.c_proj.bias,
91
+ # use_separate_proj_weight=True,
92
+ # training=self.training,
93
+ # need_weights=False
94
+ # )
95
+ #
96
+ # return x[0]
97
+
98
+ class CoordConv(nn.Module):
99
+ def __init__(self,
100
+ in_channels,
101
+ out_channels,
102
+ kernel_size=3,
103
+ padding=1,
104
+ stride=1):
105
+ super().__init__()
106
+ self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
107
+ padding, stride)
108
+
109
+ def add_coord(self, input):
110
+ b, _, h, w = input.size()
111
+ x_range = torch.linspace(-1, 1, w, device=input.device)
112
+ y_range = torch.linspace(-1, 1, h, device=input.device)
113
+ y, x = torch.meshgrid(y_range, x_range)
114
+ y = y.expand([b, 1, -1, -1])
115
+ x = x.expand([b, 1, -1, -1])
116
+ coord_feat = torch.cat([x, y], 1)
117
+ input = torch.cat([input, coord_feat], 1)
118
+ return input
119
+
120
+ def forward(self, x):
121
+ x = self.add_coord(x)
122
+ x = self.conv1(x)
123
+ return x
124
+
125
+ class TransformerDecoder(nn.Module):
126
+ def __init__(self,
127
+ num_layers,
128
+ d_model,
129
+ nhead,
130
+ dim_ffn,
131
+ dropout,
132
+ return_intermediate=False):
133
+ super().__init__()
134
+ self.layers = nn.ModuleList([
135
+ TransformerDecoderLayer(d_model=d_model,
136
+ nhead=nhead,
137
+ dim_feedforward=dim_ffn,
138
+ dropout=dropout) for _ in range(num_layers)
139
+ ])
140
+ self.num_layers = num_layers
141
+ self.norm = nn.LayerNorm(d_model)
142
+ self.return_intermediate = return_intermediate
143
+
144
+ @staticmethod
145
+ def pos1d(d_model, length):
146
+ """
147
+ :param d_model: dimension of the model
148
+ :param length: length of positions
149
+ :return: length*d_model position matrix
150
+ """
151
+ if d_model % 2 != 0:
152
+ raise ValueError("Cannot use sin/cos positional encoding with "
153
+ "odd dim (got dim={:d})".format(d_model))
154
+ pe = torch.zeros(length, d_model)
155
+ position = torch.arange(0, length).unsqueeze(1)
156
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
157
+ -(math.log(10000.0) / d_model)))
158
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
159
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
160
+
161
+ return pe.unsqueeze(1) # n, 1, 512
162
+
163
+ @staticmethod
164
+ def pos2d(d_model, height, width):
165
+ """
166
+ :param d_model: dimension of the model
167
+ :param height: height of the positions
168
+ :param width: width of the positions
169
+ :return: d_model*height*width position matrix
170
+ """
171
+ if d_model % 4 != 0:
172
+ raise ValueError("Cannot use sin/cos positional encoding with "
173
+ "odd dimension (got dim={:d})".format(d_model))
174
+ pe = torch.zeros(d_model, height, width)
175
+ # Each dimension use half of d_model
176
+ d_model = int(d_model / 2)
177
+ div_term = torch.exp(
178
+ torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
179
+ pos_w = torch.arange(0., width).unsqueeze(1)
180
+ pos_h = torch.arange(0., height).unsqueeze(1)
181
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
182
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
183
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
184
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
185
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
186
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
187
+ pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
188
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
189
+
190
+ return pe.reshape(-1, 1, height * width).permute(2, 1, 0) # hw, 1, 512
191
+
192
+ def forward(self, vis, txt, pad_mask):
193
+ '''
194
+ vis: b, 512, h, w
195
+ txt: b, L, 512
196
+ pad_mask: b, L
197
+ '''
198
+ B, C, H, W = vis.size()
199
+ _, L, D = txt.size()
200
+ # position encoding
201
+ vis_pos = self.pos2d(C, H, W)
202
+ txt_pos = self.pos1d(D, L)
203
+ # reshape & permute
204
+ vis = vis.reshape(B, C, -1).permute(2, 0, 1)
205
+ txt = txt.permute(1, 0, 2)
206
+ # forward
207
+ output = vis
208
+ intermediate = []
209
+ for layer in self.layers:
210
+ output = layer(output, txt, vis_pos, txt_pos, pad_mask)
211
+ if self.return_intermediate:
212
+ # HW, b, 512 -> b, 512, HW
213
+ intermediate.append(self.norm(output).permute(1, 2, 0))
214
+
215
+ if self.norm is not None:
216
+ # HW, b, 512 -> b, 512, HW
217
+ output = self.norm(output).permute(1, 2, 0)
218
+ if self.return_intermediate:
219
+ intermediate.pop()
220
+ intermediate.append(output)
221
+ # [output1, output2, ..., output_n]
222
+ return intermediate
223
+ else:
224
+ # b, 512, HW
225
+ return output
226
+ return output
227
+
228
+
229
+ class TransformerDecoderLayer(nn.Module):
230
+ def __init__(self,
231
+ d_model=512,
232
+ nhead=9,
233
+ dim_feedforward=2048,
234
+ dropout=0.1):
235
+ super().__init__()
236
+ # Normalization Layer
237
+ self.self_attn_norm = nn.LayerNorm(d_model)
238
+ self.cross_attn_norm = nn.LayerNorm(d_model)
239
+ # Attention Layer
240
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
241
+ self.multihead_attn = nn.MultiheadAttention(d_model,
242
+ nhead,
243
+ dropout=dropout,
244
+ kdim=d_model,
245
+ vdim=d_model)
246
+ # FFN
247
+ self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
248
+ nn.ReLU(True), nn.Dropout(dropout),
249
+ nn.LayerNorm(dim_feedforward),
250
+ nn.Linear(dim_feedforward, d_model))
251
+ # LayerNorm & Dropout
252
+ self.norm1 = nn.LayerNorm(d_model)
253
+ self.norm2 = nn.LayerNorm(d_model)
254
+ self.norm3 = nn.LayerNorm(d_model)
255
+ self.dropout1 = nn.Dropout(dropout)
256
+ self.dropout2 = nn.Dropout(dropout)
257
+ self.dropout3 = nn.Dropout(dropout)
258
+
259
+ def with_pos_embed(self, tensor, pos):
260
+ return tensor if pos is None else tensor + pos.to(tensor.device)
261
+
262
+ def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
263
+ '''
264
+ vis: 26*26, b, 512
265
+ txt: L, b, 512
266
+ vis_pos: 26*26, 1, 512
267
+ txt_pos: L, 1, 512
268
+ pad_mask: b, L
269
+ '''
270
+ # Self-Attention
271
+ vis2 = self.norm1(vis)
272
+ q = k = self.with_pos_embed(vis2, vis_pos)
273
+ vis2 = self.self_attn(q, k, value=vis2)[0]
274
+ vis2 = self.self_attn_norm(vis2)
275
+ vis = vis + self.dropout1(vis2)
276
+ # Cross-Attention
277
+ vis2 = self.norm2(vis)
278
+ vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
279
+ key=self.with_pos_embed(txt, txt_pos),
280
+ value=txt,
281
+ key_padding_mask=pad_mask)[0]
282
+ vis2 = self.cross_attn_norm(vis2)
283
+ vis = vis + self.dropout2(vis2)
284
+ # FFN
285
+ vis2 = self.norm3(vis)
286
+ vis2 = self.ffn(vis2)
287
+ vis = vis + self.dropout3(vis2)
288
+ return vis
289
+
290
+ class Text_Projector(nn.Module):
291
+ def __init__(self, args, in_channels=[512, 1024, 1024],
292
+ out_channels=[256, 512, 1024]):
293
+
294
+ super(Text_Projector, self).__init__()
295
+
296
+ self.proj = linear_layer(args, in_channels[2], out_channels[2])
297
+ self.ReLU = nn.ReLU(True)
298
+
299
+ def forward(self, text):
300
+
301
+ text = self.ReLU(text + self.proj(text))
302
+
303
+ return text
304
+
305
+ class Image_Projector(nn.Module):
306
+ def __init__(self, args, in_channels=[512, 1024, 1024],
307
+ out_channels=[256, 512, 1024]):
308
+
309
+ super(Image_Projector, self).__init__()
310
+
311
+ self.proj = linear_layer(args, in_channels[0], out_channels[2])
312
+ self.ReLU = nn.ReLU(True)
313
+
314
+ def forward(self, image):
315
+
316
+ image = self.ReLU(image + self.proj(image))
317
+
318
+ return image
319
+
320
+ class Adapter(nn.Module):
321
+ def __init__(self, c_in, reduction=4):
322
+ super(Adapter, self).__init__()
323
+ self.fc = nn.Sequential(
324
+ nn.Linear(c_in, c_in // reduction, bias=False),
325
+ nn.ReLU(inplace=True),
326
+ nn.Linear(c_in // reduction, c_in, bias=False),
327
+ nn.ReLU(inplace=True)
328
+ )
329
+
330
+ def forward(self, x):
331
+ x = self.fc(x)
332
+ return x
333
+
334
+ class GAP(nn.Module):
335
+ def __init__(self, kernel):
336
+ super(GAP, self).__init__()
337
+ self.k = kernel
338
+ # self.fc = nn.Linear(512, 1024)
339
+ def forward(self, x):
340
+ x = F.adaptive_avg_pool2d(x, self.k)
341
+
342
+ return x.squeeze(-1).squeeze(-1)
343
+
344
+ class AdaptiveSpatialFeatureFusion(nn.Module):
345
+ def __init__(self, args, in_channels=[512, 1024, 1024],
346
+ out_channels=[256, 512, 1024]):
347
+
348
+ super(AdaptiveSpatialFeatureFusion, self).__init__()
349
+ self.weight = nn.LayerNorm(out_channels[2])
350
+ self.proj = linear_layer(args, in_channels[0], out_channels[2])
351
+
352
+ def forward(self, feature_map1, feature_map2):
353
+ # feature_map1 : b, 1024, 1, 1
354
+ # feature_map2 : b, 512, 1, 1
355
+ feature_map2 = self.proj(feature_map2.squeeze(-1).squeeze(-1))
356
+ feature_map1 = feature_map1.squeeze(-1).squeeze(-1)
357
+ weights1 = torch.norm(feature_map1, dim=1).unsqueeze(-1)
358
+ weights2 = torch.norm(feature_map2, dim=1).unsqueeze(-1)
359
+ weights1 = weights1 / (weights1 + weights2)
360
+ weights2 = 1 - weights1
361
+
362
+ fused_feature_map = weights1 * feature_map1 + weights2 * feature_map2
363
+ # b, 1024
364
+ return fused_feature_map
365
+
366
+ class ModifiedAttentionPool2d(nn.Module):
367
+ def __init__(self,
368
+ spacial_dim: int,
369
+ embed_dim: int,
370
+ num_heads: int,
371
+ output_dim: int = None):
372
+ super().__init__()
373
+ self.spacial_dim = spacial_dim
374
+ self.positional_embedding = nn.Parameter(
375
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
376
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
377
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
378
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
379
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
380
+ self.num_heads = num_heads
381
+ # residual
382
+ self.connect = nn.Sequential(
383
+ nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
384
+ nn.BatchNorm2d(output_dim))
385
+
386
+ def resize_pos_embed(self, pos_embed, input_shpae):
387
+ """Resize pos_embed weights.
388
+ Resize pos_embed using bicubic interpolate method.
389
+ Args:
390
+ pos_embed (torch.Tensor): Position embedding weights.
391
+ input_shpae (tuple): Tuple for (downsampled input image height,
392
+ downsampled input image width).
393
+ pos_shape (tuple): The resolution of downsampled origin training
394
+ image.
395
+ mode (str): Algorithm used for upsampling:
396
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
397
+ ``'trilinear'``. Default: ``'nearest'``
398
+ Return:
399
+ torch.Tensor: The resized pos_embed of shape [B, C, L_new]
400
+ """
401
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
402
+ pos_h = pos_w = self.spacial_dim
403
+ cls_token_weight = pos_embed[:, 0]
404
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
405
+ pos_embed_weight = pos_embed_weight.reshape(
406
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
407
+ pos_embed_weight = F.interpolate(pos_embed_weight,
408
+ size=input_shpae,
409
+ align_corners=False,
410
+ mode='bicubic')
411
+ cls_token_weight = cls_token_weight.unsqueeze(1)
412
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
413
+ # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
414
+ return pos_embed_weight.transpose(-2, -1)
415
+
416
+ def forward(self, x):
417
+ B, C, H, W = x.size()
418
+ res = self.connect(x)
419
+ x = x.reshape(B, C, -1) # NC(HW)
420
+ # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(1+HW)
421
+ pos_embed = self.positional_embedding.unsqueeze(0)
422
+ pos_embed = self.resize_pos_embed(pos_embed, (H, W)) # NC(HW)
423
+ x = x + pos_embed.to(x.dtype) # NC(HW)
424
+ x = x.permute(2, 0, 1) # (HW)NC
425
+ x, _ = F.multi_head_attention_forward(
426
+ query=x,
427
+ key=x,
428
+ value=x,
429
+ embed_dim_to_check=x.shape[-1],
430
+ num_heads=self.num_heads,
431
+ q_proj_weight=self.q_proj.weight,
432
+ k_proj_weight=self.k_proj.weight,
433
+ v_proj_weight=self.v_proj.weight,
434
+ in_proj_weight=None,
435
+ in_proj_bias=torch.cat(
436
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
437
+ bias_k=None,
438
+ bias_v=None,
439
+ add_zero_attn=False,
440
+ dropout_p=0,
441
+ out_proj_weight=self.c_proj.weight,
442
+ out_proj_bias=self.c_proj.bias,
443
+ use_separate_proj_weight=True,
444
+ training=self.training,
445
+ need_weights=False)
446
+ xt = x[0]
447
+ x = x.permute(1, 2, 0).reshape(B, -1, H, W)
448
+ x = x + res
449
+ x = F.relu(x, True)
450
+
451
+ return x, xt
452
+
453
+ # modified
454
+ class FPN(nn.Module):
455
+ def __init__(self, args,
456
+ in_channels=[512, 1024, 1024],
457
+ out_channels=[256, 512, 1024, 1024]):
458
+ super(FPN, self).__init__()
459
+ input_resolution = args.input_size
460
+ heads = args.heads
461
+ output_dim = args.output_dim
462
+ embed_dim = args.emb_dim
463
+ # image projection
464
+ self.attn = ModifiedAttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
465
+ # text projection
466
+ self.txt_proj = linear_layer(args, in_channels[2], out_channels[2])
467
+ # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
468
+ self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
469
+
470
+ self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
471
+ nn.ReLU(True))
472
+
473
+ # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
474
+ self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
475
+ self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
476
+ out_channels[1], 1, 0)
477
+ # fusion 3: v3 & fm_mid -> f_3: b, 512, 52, 52
478
+ self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
479
+ self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
480
+ out_channels[1], 1, 0)
481
+ # fusion 4: f_3 & f_4 & f_5 -> fq: b, 256, 26, 26
482
+ self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
483
+ self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
484
+ self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
485
+ # aggregation
486
+ self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
487
+ self.coordconv = nn.Sequential(
488
+ CoordConv(out_channels[1], out_channels[1], 3, 1),
489
+ conv_layer(out_channels[1], out_channels[3], 3, 1))
490
+
491
+ def forward(self, imgs, text):
492
+ # v3, v4, v5: 256, 52, 52 / 512, 26, 26 / 1024, 13, 13
493
+ v3, v4, v5 = imgs
494
+
495
+ # fusion 1: b, 1024, 13, 13
496
+ # text projection: b, 1024 -> b, 1024
497
+ v5, _ = self.attn(v5)
498
+ text_ = self.txt_proj(text)
499
+ state = text_.unsqueeze(-1).unsqueeze(
500
+ -1)# b, 1024, 1, 1
501
+
502
+ f5 = self.f1_v_proj(v5) # b, 1024, 7, 7
503
+
504
+ f5 = self.norm_layer(f5 * state)
505
+ # fusion 2: b, 512, 26, 26
506
+ f4 = self.f2_v_proj(v4)
507
+ # f4 = f4.repeat(w2,1,1,1)
508
+
509
+ f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
510
+ f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
511
+ # fusion 3: b, 256, 26, 26
512
+ f3 = self.f3_v_proj(v3)
513
+ f3 = F.avg_pool2d(f3, 2, 2)
514
+ # f3 = f3.repeat(w2, 1, 1, 1)
515
+
516
+ f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
517
+ # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
518
+ fq5 = self.f4_proj5(f5)
519
+ fq4 = self.f4_proj4(f4)
520
+ fq3 = self.f4_proj3(f3)
521
+ # query
522
+ fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
523
+ fq = torch.cat([fq3, fq4, fq5], dim=1)
524
+ fq = self.aggr(fq)
525
+ fq = self.coordconv(fq)
526
+ # fqq = fq.reshape(w1, w2, fq.shape[1], fq.shape[2], fq.shape[3])
527
+ # b, 512, 26, 26
528
+
529
+ # elif text.shape[0] != v3.shape[0]:
530
+ #
531
+ # text = self.txt_proj(text)
532
+ # state = text.unsqueeze(-1).unsqueeze(
533
+ # -1) # b, 1024, 1, 1
534
+ # state = state.view(v5.shape[0], int(text.shape[0] / v5.shape[0]), state.shape[1], state.shape[2], state.shape[3])
535
+ #
536
+ # f5 = self.f1_v_proj(v5) # b, 1024, 7, 7
537
+ # f5 = f5.unsqueeze(1)
538
+ # f5_ = f5 * state
539
+ # f5_ = f5_.view(-1, f5.shape[2], f5.shape[3], f5.shape[4])
540
+ # f5 = self.norm_layer(f5_)
541
+ # # fusion 2: b, 512, 26, 26
542
+ # f4 = self.f2_v_proj(v4)
543
+ # # f4 = f4.repeat(w2,1,1,1)
544
+ #
545
+ # f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
546
+ # f4 = f4.repeat(int(f5_.shape[0] / f4.shape[0]), 1, 1, 1)
547
+ # f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
548
+ #
549
+ # # fusion 3: b, 256, 26, 26
550
+ # f3 = self.f3_v_proj(v3)
551
+ # f3 = F.avg_pool2d(f3, 2, 2)
552
+ # # f3 = f3.repeat(w2, 1, 1, 1)
553
+ # f3 = f3.repeat(int(f5_.shape[0] / f3.shape[0]), 1, 1, 1)
554
+ # f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
555
+ # # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
556
+ # fq5 = self.f4_proj5(f5)
557
+ # fq4 = self.f4_proj4(f4)
558
+ # fq3 = self.f4_proj3(f3)
559
+ # # query
560
+ # fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
561
+ # fq = torch.cat([fq3, fq4, fq5], dim=1)
562
+ # fq = self.aggr(fq)
563
+ # fq = self.coordconv(fq)
564
+ return fq
565
+
566
+ class ViTFPN(nn.Module):
567
+ def __init__(self, image_resolution,
568
+ in_channels=[512, 768, 768],
569
+ out_channels=[768, 768, 768, 512]):
570
+ super(ViTFPN, self).__init__()
571
+ # text projection
572
+ self.txt_proj = linear_layer(in_channels[0], out_channels[1])
573
+ # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
574
+ self.f1_v_proj = conv_layer(in_channels[1], out_channels[1], 1, 0)
575
+ self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[1]),
576
+ nn.ReLU(True))
577
+ # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
578
+ self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
579
+ self.f2_cat = conv_layer(out_channels[0] + out_channels[0],
580
+ out_channels[0], 1, 0)
581
+ # fusion 3: v3 & fm_mid -> f_3: b, 512, 52, 52
582
+ self.f3_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
583
+ self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
584
+ out_channels[1], 1, 0)
585
+ # fusion 4: f_3 & f_4 & f_5 -> fq: b, 256, 26, 26
586
+ self.f4_proj5 = conv_layer(out_channels[1], out_channels[0], 3, 1)
587
+ self.f4_proj4 = conv_layer(out_channels[0], out_channels[0], 3, 1)
588
+ self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
589
+ # aggregation
590
+ self.aggr = conv_layer(3 * out_channels[0], out_channels[0], 1, 0)
591
+ self.coordconv = nn.Sequential(
592
+ CoordConv(out_channels[0], out_channels[0], 3, 1),
593
+ conv_layer(out_channels[0], out_channels[-1], 3, 1))
594
+
595
+ self.attnpool = AttentionPool2d(image_resolution // 32, out_channels[-1],
596
+ 8, out_channels[-1])
597
+ def forward(self, imgs, state, vis):
598
+ # v1 / v2 / b, 49, 1024/ b, 196, 512
599
+ v3, v4, v5 = imgs
600
+ # fusion 1: b, 1024, 13, 13
601
+ # text projection: b, 1024 -> b, 1024
602
+ state = self.txt_proj(state)
603
+ state = state.unsqueeze(-1).unsqueeze(
604
+ -1)# b, 1024, 1, 1
605
+ f5 = self.f1_v_proj(v5)
606
+ f5 = self.norm_layer(f5 * state)
607
+ # fusion 2: b, 512, 26, 26
608
+ f4 = self.f2_v_proj(v4)
609
+ b, c, h, w = f4.size()
610
+ f5_ = F.interpolate(f5, (h, w), mode='bilinear')
611
+ f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
612
+
613
+ # fusion 3: b, 256, 26, 26
614
+ f3 = self.f3_v_proj(v3)
615
+ f3 = F.avg_pool2d(f3, 2, 2)
616
+ # f3 = f3.repeat(w2, 1, 1, 1)
617
+
618
+ f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
619
+ # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
620
+ fq5 = self.f4_proj5(f5)
621
+ fq4 = self.f4_proj4(f4)
622
+ fq3 = self.f4_proj3(f3)
623
+ # query
624
+ fq5 = F.interpolate(fq5, (h, w), mode='bilinear')
625
+ fq = torch.cat([fq3, fq4, fq5], dim=1)
626
+ fq = self.aggr(fq)
627
+ if not vis:
628
+ fq = self.coordconv(fq)
629
+ fq = self.attnpool(fq)
630
+ # b, 512, 26, 26
631
+ return fq
632
+
633
+
cisen/model/segmenter.py ADDED
@@ -0,0 +1,2045 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from .clip import build_model, build_promptlearner, build_modified_model, PromptLearner, build_lclip_model
7
+ from torch.cuda.amp import autocast as autocast
8
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
9
+ from timm.models.layers import variance_scaling_
10
+ from einops import rearrange, repeat
11
+ from loguru import logger
12
+ from transformers import AlignProcessor, AlignModel
13
+ from sklearn.metrics import classification_report
14
+ from huggingface_hub import PyTorchModelHubMixin
15
+ from .layers import FPN, TransformerDecoder, ViTFPN, AdaptiveSpatialFeatureFusion, Text_Projector, Image_Projector, Adapter, GAP
16
+ from cisen.model.clip import CLIP
17
+ def lecun_normal_(tensor):
18
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
19
+
20
+ def trunc_normal_(tensor, mean=0.0, std=1.0):
21
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
22
+
23
+ class CISEN_vit(nn.Module, PyTorchModelHubMixin):
24
+ def __init__(self, cfg):
25
+ super().__init__()
26
+ # Vision & Text Encoder & Label Encoder
27
+ clip_model = torch.jit.load(cfg.clip_pretrain,
28
+ map_location="cpu").eval()
29
+
30
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
31
+ self.backbone = backbone.float()
32
+ self.patch_emb = image_resolution // patch_size
33
+ cfg.image_resolution = image_resolution
34
+ cfg.input_size = image_resolution
35
+ cfg.heads = vision_heads // 32
36
+ cfg.emb_dim = vision_width
37
+ cfg.output_dim = embed_dim
38
+
39
+ # multi-scale adapter
40
+ # Multi-Modal FPN
41
+ self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
42
+ # Fined-grained Fusion
43
+ # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
44
+ # d_model=cfg.vis_dim,
45
+ # nhead=cfg.num_head,
46
+ # dim_ffn=cfg.dim_ffn,
47
+ # dropout=cfg.dropout,
48
+ # return_intermediate=cfg.intermediate)
49
+
50
+ # image-text transformer
51
+ # self.trans = nn.Linear(1024, 1024)
52
+ self.ADP = Adapter(cfg.output_dim, 4)
53
+ # parameter
54
+ self.ratio = cfg.ratio
55
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
56
+ self.share_temperature = True
57
+ self.ce = nn.CrossEntropyLoss()
58
+ self.ms_adaptor = nn.ModuleList(
59
+ [
60
+ nn.Sequential(
61
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
62
+ nn.GroupNorm(32, cfg.emb_dim),
63
+ nn.GELU(),
64
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
65
+ ),
66
+ nn.Sequential(
67
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
68
+ ),
69
+ nn.Sequential(
70
+ nn.Identity(),
71
+ ),
72
+ nn.Sequential(
73
+ nn.MaxPool2d(2),
74
+ ),
75
+
76
+ ]
77
+ )
78
+
79
+ self.ms_adaptor.apply(self.init_adaptor)
80
+ def init_adaptor(self, m):
81
+ if isinstance(m, nn.Conv2d):
82
+ lecun_normal_(m.weight)
83
+ if m.bias is not None:
84
+ nn.init.constant_(m.bias, 0)
85
+ elif isinstance(m, nn.GroupNorm):
86
+ nn.init.constant_(m.bias, 0)
87
+ nn.init.constant_(m.weight, 1.0)
88
+ elif isinstance(m, nn.ConvTranspose2d):
89
+ lecun_normal_(m.weight)
90
+ if m.bias is not None:
91
+ nn.init.zeros_(m.bias)
92
+ # self.fc = nn.Linear(512, cfg.num_classes)
93
+
94
+
95
+ def IT_loss(self, image_features, text_features):
96
+ # b, 1024 / b, 1024
97
+ batch = image_features.shape[0]
98
+ # # normalized features
99
+ image_features = image_features / image_features.norm(dim=-1,
100
+ keepdim=True)
101
+ text_features = text_features / text_features.norm(dim=-1,
102
+ keepdim=True)
103
+
104
+ # cosine similarity as logits
105
+ logit_scale = self.logit_scale.exp()
106
+ logits_per_image = logit_scale * image_features @ text_features.t()
107
+ logits_per_text = logits_per_image.t()
108
+
109
+ # shape = [global_batch_size, global_batch_size]
110
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
111
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
112
+
113
+
114
+ return contrastive_loss
115
+
116
+ def forward(self, img, txt, stage):
117
+
118
+ if stage == '1st':
119
+ '''
120
+ img: b, 3, h, w
121
+ word: b, words
122
+ word_mask: b, words
123
+ mask: b, 1, h, w
124
+ stage: 1st or 2nd stage
125
+ '''
126
+ # padding mask used in decoder
127
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
128
+
129
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
130
+ # word: b, length, 512
131
+ # text: b, 1024
132
+ # image: b, 1024
133
+ vis, image = self.backbone.encode_image(img)
134
+
135
+ word, text = self.backbone.encode_text(txt)
136
+
137
+ x = self.ADP(image)
138
+
139
+ x = self.ratio * x + (1-self.ratio) * image
140
+
141
+ # b, 1024
142
+ # fq_t = self.FPN(vis, x)
143
+ #
144
+ # fv_t = self.gap(fq_t)
145
+
146
+ loss1 = self.IT_loss(x, text)
147
+
148
+ loss = loss1
149
+
150
+ ft = text
151
+ fi = x
152
+ fv = None
153
+ elif stage == '2nd':
154
+ '''
155
+ img: b, 3, h, w
156
+ word: b, words
157
+ word_mask: b, words
158
+ mask: b, 1, h, w
159
+ stage: 1st or 2nd stage
160
+ '''
161
+ # padding mask used in decoder
162
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
163
+
164
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
165
+ # word: b, length, 512
166
+ # text: b, 1024
167
+ # image: b, 1024
168
+ vis, image = self.backbone.encode_image(img)
169
+
170
+ word, text = self.backbone.encode_text(txt)
171
+
172
+ x = self.ADP(image)
173
+
174
+ x = self.ratio * x + (1 - self.ratio) * image
175
+ # Construct multi-scale feats
176
+ vis_trans = []
177
+ for i in range(len(self.ms_adaptor)):
178
+ x_ = rearrange(
179
+ vis[i],
180
+ "b (h w) c -> b c h w",
181
+ h=self.patch_emb,
182
+ w=self.patch_emb,
183
+ ).contiguous()
184
+
185
+ feats = self.ms_adaptor[i](x_)
186
+
187
+ vis_trans.append(feats)
188
+
189
+ # fq = self.FPN(vis, x_t)
190
+ fv_t = self.FPN(vis_trans[1:], x, False)
191
+ # fv_t = self.gap(fq_t)
192
+
193
+ # b, 1024
194
+
195
+ loss2 = self.IT_loss(fv_t, text)
196
+
197
+ loss = (loss2)
198
+ fv = fv_t
199
+ ft = text
200
+ fi = x
201
+
202
+
203
+ return loss, fv, fi, ft
204
+
205
+ def visualize(self, img, txt):
206
+ vis, image = self.backbone.encode_image(img)
207
+ word, text = self.backbone.encode_text(txt)
208
+
209
+ x = self.ADP(image)
210
+
211
+ x = self.ratio * x + (1 - self.ratio) * image
212
+ # Construct multi-scale feats
213
+ vis_trans = []
214
+ for i in range(len(self.ms_adaptor)):
215
+ x_ = rearrange(
216
+ vis[i],
217
+ "b (h w) c -> b c h w",
218
+ h=self.patch_emb,
219
+ w=self.patch_emb,
220
+ ).contiguous()
221
+
222
+ feats = self.ms_adaptor[i](x_)
223
+
224
+ vis_trans.append(feats)
225
+
226
+ # fq = self.FPN(vis, x_t)
227
+ fv_t = self.FPN(vis_trans[1:], x, True)
228
+ ft_t = self.FPN(vis_trans[1:], text, True)
229
+ return vis, fv_t, ft_t
230
+
231
+ class CISEN_rsvit(nn.Module, PyTorchModelHubMixin):
232
+ def __init__(self, cfg):
233
+ super().__init__()
234
+ # Vision & Text Encoder & Label Encoder
235
+ clip_model = torch.load(cfg.clip_pretrain,
236
+ map_location="cpu")
237
+
238
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
239
+ self.backbone = backbone.float()
240
+ self.patch_emb = image_resolution // patch_size
241
+
242
+ cfg.image_resolution = image_resolution
243
+ cfg.input_size = image_resolution
244
+ cfg.heads = vision_heads // 32
245
+ cfg.emb_dim = vision_width
246
+ cfg.output_dim = embed_dim
247
+
248
+ # multi-scale adapter
249
+ # Multi-Modal FPN
250
+ self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
251
+ # Fined-grained Fusion
252
+ # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
253
+ # d_model=cfg.vis_dim,
254
+ # nhead=cfg.num_head,
255
+ # dim_ffn=cfg.dim_ffn,
256
+ # dropout=cfg.dropout,
257
+ # return_intermediate=cfg.intermediate)
258
+
259
+ # image-text transformer
260
+ # self.trans = nn.Linear(1024, 1024)
261
+ self.ADP = Adapter(cfg.output_dim, 4)
262
+ # parameter
263
+ self.ratio = cfg.ratio
264
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
265
+ self.share_temperature = True
266
+ self.ce = nn.CrossEntropyLoss()
267
+ self.ms_adaptor = nn.ModuleList(
268
+ [
269
+ nn.Sequential(
270
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
271
+ nn.GroupNorm(32, cfg.emb_dim),
272
+ nn.GELU(),
273
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
274
+ ),
275
+ nn.Sequential(
276
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
277
+ ),
278
+ nn.Sequential(
279
+ nn.Identity(),
280
+ ),
281
+ nn.Sequential(
282
+ nn.MaxPool2d(2),
283
+ ),
284
+
285
+ ]
286
+ )
287
+
288
+ self.ms_adaptor.apply(self.init_adaptor)
289
+ def init_adaptor(self, m):
290
+ if isinstance(m, nn.Conv2d):
291
+ lecun_normal_(m.weight)
292
+ if m.bias is not None:
293
+ nn.init.constant_(m.bias, 0)
294
+ elif isinstance(m, nn.GroupNorm):
295
+ nn.init.constant_(m.bias, 0)
296
+ nn.init.constant_(m.weight, 1.0)
297
+ elif isinstance(m, nn.ConvTranspose2d):
298
+ lecun_normal_(m.weight)
299
+ if m.bias is not None:
300
+ nn.init.zeros_(m.bias)
301
+ # self.fc = nn.Linear(512, cfg.num_classes)
302
+
303
+
304
+ def IT_loss(self, image_features, text_features):
305
+ # b, 1024 / b, 1024
306
+ batch = image_features.shape[0]
307
+ # # normalized features
308
+ image_features = image_features / image_features.norm(dim=-1,
309
+ keepdim=True)
310
+ text_features = text_features / text_features.norm(dim=-1,
311
+ keepdim=True)
312
+
313
+ # cosine similarity as logits
314
+ logit_scale = self.logit_scale.exp()
315
+ logits_per_image = logit_scale * image_features @ text_features.t()
316
+ logits_per_text = logits_per_image.t()
317
+
318
+ # shape = [global_batch_size, global_batch_size]
319
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
320
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
321
+
322
+
323
+ return contrastive_loss
324
+ def image_encode(self, img):
325
+ vis, image = self.backbone.encode_image(img)
326
+
327
+ x = self.ADP(image)
328
+
329
+ x = self.ratio * x + (1 - self.ratio) * image
330
+ return x
331
+
332
+ def text_encode(self, txt):
333
+
334
+ word, text = self.backbone.encode_text(txt)
335
+
336
+ return text
337
+
338
+ def forward(self, img, txt, stage):
339
+
340
+ if stage == '1st':
341
+ '''
342
+ img: b, 3, h, w
343
+ word: b, words
344
+ word_mask: b, words
345
+ mask: b, 1, h, w
346
+ stage: 1st or 2nd stage
347
+ '''
348
+ # padding mask used in decoder
349
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
350
+
351
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
352
+ # word: b, length, 512
353
+ # text: b, 1024
354
+ # image: b, 1024
355
+ vis, image = self.backbone.encode_image(img)
356
+
357
+ word, text = self.backbone.encode_text(txt)
358
+
359
+ x = self.ADP(image)
360
+
361
+ x = self.ratio * x + (1-self.ratio) * image
362
+
363
+ # b, 1024
364
+ # fq_t = self.FPN(vis, x)
365
+ #
366
+ # fv_t = self.gap(fq_t)
367
+
368
+ loss1 = self.IT_loss(x, text)
369
+
370
+ loss = loss1
371
+
372
+ ft = text
373
+ fi = x
374
+ fv = None
375
+ elif stage == '2nd':
376
+ '''
377
+ img: b, 3, h, w
378
+ word: b, words
379
+ word_mask: b, words
380
+ mask: b, 1, h, w
381
+ stage: 1st or 2nd stage
382
+ '''
383
+ # padding mask used in decoder
384
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
385
+
386
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
387
+ # word: b, length, 512
388
+ # text: b, 1024
389
+ # image: b, 1024
390
+ vis, image = self.backbone.encode_image(img)
391
+
392
+ word, text = self.backbone.encode_text(txt)
393
+
394
+ x = self.ADP(image)
395
+
396
+ x = self.ratio * x + (1 - self.ratio) * image
397
+ # Construct multi-scale feats
398
+ vis_trans = []
399
+ for i in range(len(self.ms_adaptor)):
400
+ x_ = rearrange(
401
+ vis[i],
402
+ "b (h w) c -> b c h w",
403
+ h=self.patch_emb,
404
+ w=self.patch_emb,
405
+ ).contiguous()
406
+
407
+ feats = self.ms_adaptor[i](x_)
408
+
409
+ vis_trans.append(feats)
410
+
411
+ # fq = self.FPN(vis, x_t)
412
+ fv_t = self.FPN(vis_trans[1:], x, False)
413
+ # fv_t = self.gap(fq_t)
414
+
415
+ # b, 1024
416
+
417
+ loss2 = self.IT_loss(fv_t, text)
418
+
419
+ loss = (loss2)
420
+ fv = fv_t
421
+ ft = text
422
+ fi = x
423
+
424
+
425
+ return loss, fv, fi, ft
426
+
427
+ def visualize(self, img):
428
+ vis, image = self.backbone.encode_image(img)
429
+
430
+
431
+ x = self.ADP(image)
432
+
433
+ x = self.ratio * x + (1 - self.ratio) * image
434
+ # Construct multi-scale feats
435
+ vis_trans = []
436
+ for i in range(len(self.ms_adaptor)):
437
+ x_ = rearrange(
438
+ vis[i],
439
+ "b (h w) c -> b c h w",
440
+ h=self.patch_emb,
441
+ w=self.patch_emb,
442
+ ).contiguous()
443
+
444
+ feats = self.ms_adaptor[i](x_)
445
+
446
+ vis_trans.append(feats)
447
+
448
+
449
+ fv_t = self.FPN(vis_trans[1:], x, True)
450
+ return vis, fv_t
451
+
452
+ class CISEN_vit(nn.Module, PyTorchModelHubMixin):
453
+ def __init__(self, cfg):
454
+ super().__init__()
455
+ # Vision & Text Encoder & Label Encoder
456
+ clip_model = torch.jit.load(cfg.clip_pretrain,
457
+ map_location="cpu").eval()
458
+
459
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
460
+ self.backbone = backbone.float()
461
+ self.patch_emb = image_resolution // patch_size
462
+ cfg.image_resolution = image_resolution
463
+ cfg.input_size = image_resolution
464
+ cfg.heads = vision_heads // 32
465
+ cfg.emb_dim = vision_width
466
+ cfg.output_dim = embed_dim
467
+
468
+ # multi-scale adapter
469
+ # Multi-Modal FPN
470
+ self.FPN = ViTFPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
471
+ # Fined-grained Fusion
472
+ # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
473
+ # d_model=cfg.vis_dim,
474
+ # nhead=cfg.num_head,
475
+ # dim_ffn=cfg.dim_ffn,
476
+ # dropout=cfg.dropout,
477
+ # return_intermediate=cfg.intermediate)
478
+
479
+ # image-text transformer
480
+ # self.trans = nn.Linear(1024, 1024)
481
+ self.ADP = Adapter(cfg.output_dim, 4)
482
+ # parameter
483
+ self.ratio = cfg.ratio
484
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
485
+ self.share_temperature = True
486
+ self.ce = nn.CrossEntropyLoss()
487
+ self.ms_adaptor = nn.ModuleList(
488
+ [
489
+ nn.Sequential(
490
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
491
+ nn.GroupNorm(32, cfg.emb_dim),
492
+ nn.GELU(),
493
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
494
+ ),
495
+ nn.Sequential(
496
+ nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
497
+ ),
498
+ nn.Sequential(
499
+ nn.Identity(),
500
+ ),
501
+ nn.Sequential(
502
+ nn.MaxPool2d(2),
503
+ ),
504
+
505
+ ]
506
+ )
507
+
508
+ self.ms_adaptor.apply(self.init_adaptor)
509
+ def init_adaptor(self, m):
510
+ if isinstance(m, nn.Conv2d):
511
+ lecun_normal_(m.weight)
512
+ if m.bias is not None:
513
+ nn.init.constant_(m.bias, 0)
514
+ elif isinstance(m, nn.GroupNorm):
515
+ nn.init.constant_(m.bias, 0)
516
+ nn.init.constant_(m.weight, 1.0)
517
+ elif isinstance(m, nn.ConvTranspose2d):
518
+ lecun_normal_(m.weight)
519
+ if m.bias is not None:
520
+ nn.init.zeros_(m.bias)
521
+ # self.fc = nn.Linear(512, cfg.num_classes)
522
+
523
+
524
+ def IT_loss(self, image_features, text_features):
525
+ # b, 1024 / b, 1024
526
+ batch = image_features.shape[0]
527
+ # # normalized features
528
+ image_features = image_features / image_features.norm(dim=-1,
529
+ keepdim=True)
530
+ text_features = text_features / text_features.norm(dim=-1,
531
+ keepdim=True)
532
+
533
+ # cosine similarity as logits
534
+ logit_scale = self.logit_scale.exp()
535
+ logits_per_image = logit_scale * image_features @ text_features.t()
536
+ logits_per_text = logits_per_image.t()
537
+
538
+ # shape = [global_batch_size, global_batch_size]
539
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
540
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
541
+
542
+
543
+ return contrastive_loss
544
+
545
+ def forward(self, img, txt, stage):
546
+
547
+ if stage == '1st':
548
+ '''
549
+ img: b, 3, h, w
550
+ word: b, words
551
+ word_mask: b, words
552
+ mask: b, 1, h, w
553
+ stage: 1st or 2nd stage
554
+ '''
555
+ # padding mask used in decoder
556
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
557
+
558
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
559
+ # word: b, length, 512
560
+ # text: b, 1024
561
+ # image: b, 1024
562
+ vis, image = self.backbone.encode_image(img)
563
+
564
+ word, text = self.backbone.encode_text(txt)
565
+
566
+ x = self.ADP(image)
567
+
568
+ x = self.ratio * x + (1-self.ratio) * image
569
+
570
+ # b, 1024
571
+ # fq_t = self.FPN(vis, x)
572
+ #
573
+ # fv_t = self.gap(fq_t)
574
+
575
+ loss1 = self.IT_loss(x, text)
576
+
577
+ loss = loss1
578
+
579
+ ft = text
580
+ fi = x
581
+ fv = None
582
+ elif stage == '2nd':
583
+ '''
584
+ img: b, 3, h, w
585
+ word: b, words
586
+ word_mask: b, words
587
+ mask: b, 1, h, w
588
+ stage: 1st or 2nd stage
589
+ '''
590
+ # padding mask used in decoder
591
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
592
+
593
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
594
+ # word: b, length, 512
595
+ # text: b, 1024
596
+ # image: b, 1024
597
+ vis, image = self.backbone.encode_image(img)
598
+
599
+ word, text = self.backbone.encode_text(txt)
600
+
601
+ x = self.ADP(image)
602
+
603
+ x = self.ratio * x + (1 - self.ratio) * image
604
+ # Construct multi-scale feats
605
+ vis_trans = []
606
+ for i in range(len(self.ms_adaptor)):
607
+ x_ = rearrange(
608
+ vis[i],
609
+ "b (h w) c -> b c h w",
610
+ h=self.patch_emb,
611
+ w=self.patch_emb,
612
+ ).contiguous()
613
+
614
+ feats = self.ms_adaptor[i](x_)
615
+
616
+ vis_trans.append(feats)
617
+
618
+ # fq = self.FPN(vis, x_t)
619
+ fv_t = self.FPN(vis_trans[1:], x, False)
620
+ # fv_t = self.gap(fq_t)
621
+
622
+ # b, 1024
623
+
624
+ loss2 = self.IT_loss(fv_t, text)
625
+
626
+ loss = (loss2)
627
+ fv = fv_t
628
+ ft = text
629
+ fi = x
630
+
631
+
632
+ return loss, fv, fi, ft
633
+
634
+ def visualize(self, img, txt):
635
+ vis, image = self.backbone.encode_image(img)
636
+ word, text = self.backbone.encode_text(txt)
637
+
638
+ x = self.ADP(image)
639
+
640
+ x = self.ratio * x + (1 - self.ratio) * image
641
+ # Construct multi-scale feats
642
+ vis_trans = []
643
+ for i in range(len(self.ms_adaptor)):
644
+ x_ = rearrange(
645
+ vis[i],
646
+ "b (h w) c -> b c h w",
647
+ h=self.patch_emb,
648
+ w=self.patch_emb,
649
+ ).contiguous()
650
+
651
+ feats = self.ms_adaptor[i](x_)
652
+
653
+ vis_trans.append(feats)
654
+
655
+ # fq = self.FPN(vis, x_t)
656
+ fv_t = self.FPN(vis_trans[1:], x, True)
657
+ ft_t = self.FPN(vis_trans[1:], text, True)
658
+ return vis, fv_t, ft_t
659
+
660
+ class CISEN_rsvit_classification(nn.Module):
661
+ def __init__(self, cfg):
662
+ super().__init__()
663
+ # Vision & Text Encoder & Label Encoder
664
+ clip_model = torch.load(cfg.clip_pretrain,
665
+ map_location="cpu")
666
+
667
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
668
+ self.backbone = backbone.float()
669
+ self.patch_emb = image_resolution // patch_size
670
+ num_classes_fc = 512
671
+ num_classes_output = 10
672
+ self.num_classes_fc = num_classes_fc # Number of classes for fully connected layer
673
+ self.num_classes_output = num_classes_output # Number of classes for output layer
674
+
675
+ # Add a fully connected layer
676
+ self.fc = nn.Linear(in_features=cfg.vis_dim, out_features=num_classes_fc)
677
+
678
+ # Add an output layer for multi-label classification
679
+ self.output_layer = nn.Linear(in_features=num_classes_fc, out_features=num_classes_output)
680
+ self.criterion = nn.BCEWithLogitsLoss()
681
+ cfg.image_resolution = image_resolution
682
+ cfg.input_size = image_resolution
683
+ cfg.heads = vision_heads // 32
684
+ cfg.emb_dim = vision_width
685
+ cfg.output_dim = embed_dim
686
+
687
+
688
+ def IT_loss(self, labels, labels_pre):
689
+
690
+ labels = labels.squeeze(1)
691
+
692
+ loss = self.criterion(labels_pre, labels)
693
+ return loss
694
+
695
+ def forward(self, img, labels):
696
+ _, image_features = self.backbone.encode_image(img)
697
+ # Fully connected layer
698
+ fc_output = self.fc(image_features)
699
+ # Apply ReLU activation function
700
+ fc_output = F.relu(fc_output)
701
+ # Output layer for multi-label classification
702
+
703
+ labels_pre = self.output_layer(fc_output)
704
+
705
+ loss2 = self.IT_loss(labels, labels_pre)
706
+
707
+ return labels_pre, loss2
708
+
709
+
710
+ class CISEN_new(nn.Module):
711
+ def __init__(self, cfg):
712
+ super().__init__()
713
+ # Vision & Text Encoder & Label Encoder
714
+ clip_model = torch.jit.load(cfg.clip_pretrain,
715
+ map_location="cpu").eval()
716
+
717
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_model(clip_model.state_dict(), cfg.word_len)
718
+ self.backbone = backbone.float()
719
+ cfg.input_size = image_resolution
720
+ cfg.heads = vision_heads
721
+ cfg.emb_dim = vision_width * 32
722
+ cfg.output_dim = embed_dim
723
+ # Multi-Modal FPN
724
+ self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
725
+ # Fined-grained Fusion
726
+ # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
727
+ # d_model=cfg.vis_dim,
728
+ # nhead=cfg.num_head,
729
+ # dim_ffn=cfg.dim_ffn,
730
+ # dropout=cfg.dropout,
731
+ # return_intermediate=cfg.intermediate)
732
+
733
+ # image-text transformer
734
+ # self.trans = nn.Linear(1024, 1024)
735
+ self.ADP = Adapter(cfg.output_dim, 4)
736
+ self.gap = GAP((1,1))
737
+ # parameter
738
+ self.ratio = cfg.ratio
739
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
740
+ self.share_temperature = True
741
+ self.margin = 1
742
+ self.eps = 1e-3
743
+ self.ce = nn.CrossEntropyLoss()
744
+ #1st stage
745
+ self.lamda1 = cfg.lamda1
746
+ self.lamda2 = cfg.lamda2
747
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
748
+ # self.fc = nn.Linear(512, cfg.num_classes)
749
+
750
+
751
+ def IT_loss(self, image_features, text_features):
752
+ # b, 1024 / b, 1024
753
+ batch = image_features.shape[0]
754
+ # # normalized features
755
+ image_features = image_features / image_features.norm(dim=-1,
756
+ keepdim=True)
757
+ text_features = text_features / text_features.norm(dim=-1,
758
+ keepdim=True)
759
+
760
+ # cosine similarity as logits
761
+ logit_scale = self.logit_scale.exp()
762
+ logits_per_image = logit_scale * image_features @ text_features.t()
763
+ logits_per_text = logits_per_image.t()
764
+
765
+ # shape = [global_batch_size, global_batch_size]
766
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
767
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
768
+
769
+
770
+ return contrastive_loss
771
+
772
+ def forward(self, img, txt, stage):
773
+
774
+ if stage == '1st':
775
+ '''
776
+ img: b, 3, h, w
777
+ word: b, words
778
+ word_mask: b, words
779
+ mask: b, 1, h, w
780
+ stage: 1st or 2nd stage
781
+ '''
782
+ # padding mask used in decoder
783
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
784
+
785
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
786
+ # word: b, length, 512
787
+ # text: b, 1024
788
+ # image: b, 1024
789
+ vis, image = self.backbone.encode_image(img)
790
+
791
+ word, text = self.backbone.encode_text(txt)
792
+
793
+ x = self.ADP(image)
794
+
795
+ x = self.ratio * x + (1-self.ratio) * image
796
+
797
+ # b, 1024
798
+ # fq_t = self.FPN(vis, x)
799
+ #
800
+ # fv_t = self.gap(fq_t)
801
+
802
+ loss1 = self.IT_loss(x, text)
803
+
804
+ loss = loss1
805
+
806
+ ft = text
807
+ fi = x
808
+ fv = None
809
+ elif stage == '2nd':
810
+ '''
811
+ img: b, 3, h, w
812
+ word: b, words
813
+ word_mask: b, words
814
+ mask: b, 1, h, w
815
+ stage: 1st or 2nd stage
816
+ '''
817
+ # padding mask used in decoder
818
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
819
+
820
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
821
+ # word: b, length, 512
822
+ # text: b, 1024
823
+ # image: b, 1024
824
+ vis, image = self.backbone.encode_image(img)
825
+
826
+ word, text = self.backbone.encode_text(txt)
827
+
828
+ x = self.ADP(image)
829
+
830
+ x = self.ratio * x + (1 - self.ratio) * image
831
+
832
+ # x_t = self.trans(x)
833
+ # fq = self.FPN(vis, x_t)
834
+ fq_t = self.FPN(vis, x)
835
+
836
+ fv_t = self.gap(fq_t)
837
+
838
+ # b, 1024
839
+
840
+ loss2 = self.IT_loss(fv_t, text)
841
+
842
+ loss = (loss2)
843
+ fv = fv_t
844
+ ft = text
845
+ fi = x
846
+ elif stage == '3rd':
847
+ '''
848
+ img: b, 3, h, w
849
+ word: b, words
850
+ word_mask: b, words
851
+ mask: b, 1, h, w
852
+ stage: 1st or 2nd stage
853
+ '''
854
+ # padding mask used in decoder
855
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
856
+
857
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
858
+ # word: b, length, 512
859
+ # text: b, 1024
860
+ # image: b, 1024
861
+ vis, image = self.backbone.encode_image(img)
862
+
863
+ word, text = self.backbone.encode_text(txt)
864
+
865
+ x = self.ADP(text)
866
+ ratio = 0.2
867
+ x = ratio * x + (1 - ratio) * text
868
+
869
+ # x_t = self.trans(x)
870
+ # fq = self.FPN(vis, x_t)
871
+
872
+ # b, 1024
873
+ loss1 = self.IT_loss(image, x)
874
+
875
+
876
+ loss = loss1
877
+ fv = None
878
+ ft = x
879
+ fi = image
880
+ elif stage == '4th':
881
+ '''
882
+ img: b, 3, h, w
883
+ word: b, words
884
+ word_mask: b, words
885
+ mask: b, 1, h, w
886
+ stage: 1st or 2nd stage
887
+ '''
888
+ # padding mask used in decoder
889
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
890
+
891
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
892
+ # word: b, length, 512
893
+ # text: b, 1024
894
+ # image: b, 1024
895
+ vis, image = self.backbone.encode_image(img)
896
+ word, text = self.backbone.encode_text(txt)
897
+ # x = self.ADP(image)
898
+ # ratio = 0.2
899
+ # x = ratio * x + (1 - ratio) * text
900
+ fq_t = self.FPN(vis, image)
901
+
902
+ fv_t = self.gap(fq_t)
903
+ ratio_1 = 0.2
904
+ # b, 1024
905
+ loss2 = self.IT_loss(fv_t, text)
906
+
907
+ loss = loss2
908
+ fv = fv_t
909
+ fi = None
910
+ ft = text
911
+ elif stage == '5th':
912
+ '''
913
+ img: b, 3, h, w
914
+ word: b, words
915
+ word_mask: b, words
916
+ mask: b, 1, h, w
917
+ stage: 1st or 2nd stage
918
+ '''
919
+ # padding mask used in decoder
920
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
921
+
922
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
923
+ # word: b, length, 512
924
+ # text: b, 1024
925
+ # image: b, 1024
926
+ vis, image = self.backbone.encode_image(img)
927
+ word, text = self.backbone.encode_text(txt)
928
+ x = self.ADP(image)
929
+ ratio = 0.2
930
+ x = ratio * x + (1 - ratio) * image
931
+
932
+ y = self.ADP_t(text)
933
+ ratio_1 = 0.2
934
+ y = ratio * y + (1 - ratio_1) * text
935
+
936
+ fq_t = self.FPN(vis, image)
937
+
938
+ fv_t = self.gap(fq_t)
939
+
940
+
941
+ # b, 1024
942
+
943
+ loss2 = self.IT_loss(fv_t, y)
944
+
945
+ loss = loss2
946
+ fv = fv_t
947
+ fi = x
948
+ ft = y
949
+
950
+ return loss, fv, fi, ft
951
+
952
+ class CISEN_lclip(nn.Module):
953
+ def __init__(self, cfg):
954
+ super().__init__()
955
+ # Vision & Text Encoder & Label Encoder
956
+ clip_model = torch.load(cfg.clip_pretrain,
957
+ map_location="cpu")
958
+ # print(type(clip_model))
959
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_lclip_model(clip_model, load_from_clip=True)
960
+ self.backbone = backbone.float()
961
+ cfg.input_size = image_resolution
962
+ cfg.heads = vision_heads // 32
963
+ cfg.emb_dim = vision_width
964
+ cfg.output_dim = embed_dim
965
+ # Multi-Modal FPN
966
+ self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
967
+ # Fined-grained Fusion
968
+ # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
969
+ # d_model=cfg.vis_dim,
970
+ # nhead=cfg.num_head,
971
+ # dim_ffn=cfg.dim_ffn,
972
+ # dropout=cfg.dropout,
973
+ # return_intermediate=cfg.intermediate)
974
+
975
+ # image-text transformer
976
+ # self.trans = nn.Linear(1024, 1024)
977
+ self.ADP = Adapter(cfg.output_dim, 4)
978
+ self.gap = GAP((1,1))
979
+ # parameter
980
+ self.ratio = cfg.ratio
981
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
982
+ self.share_temperature = True
983
+ self.margin = 1
984
+ self.eps = 1e-3
985
+ self.ce = nn.CrossEntropyLoss()
986
+ #1st stage
987
+ self.lamda1 = cfg.lamda1
988
+ self.lamda2 = cfg.lamda2
989
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
990
+ # self.fc = nn.Linear(512, cfg.num_classes)
991
+
992
+
993
+ def IT_loss(self, image_features, text_features):
994
+ # b, 1024 / b, 1024
995
+ batch = image_features.shape[0]
996
+ # # normalized features
997
+ image_features = image_features / image_features.norm(dim=-1,
998
+ keepdim=True)
999
+ text_features = text_features / text_features.norm(dim=-1,
1000
+ keepdim=True)
1001
+
1002
+ # cosine similarity as logits
1003
+ logit_scale = self.logit_scale.exp()
1004
+ logits_per_image = logit_scale * image_features @ text_features.t()
1005
+ logits_per_text = logits_per_image.t()
1006
+
1007
+ # shape = [global_batch_size, global_batch_size]
1008
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
1009
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
1010
+
1011
+
1012
+ return contrastive_loss
1013
+
1014
+ def forward(self, img, txt, stage):
1015
+
1016
+ if stage == '1st':
1017
+ '''
1018
+ img: b, 3, h, w
1019
+ word: b, words
1020
+ word_mask: b, words
1021
+ mask: b, 1, h, w
1022
+ stage: 1st or 2nd stage
1023
+ '''
1024
+ # padding mask used in decoder
1025
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1026
+
1027
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1028
+ # word: b, length, 512
1029
+ # text: b, 1024
1030
+ # image: b, 1024
1031
+ vis, image = self.backbone.encode_image(img)
1032
+
1033
+ text = self.backbone.encode_text(txt)
1034
+
1035
+ x = self.ADP(image)
1036
+
1037
+ x = self.ratio * x + (1-self.ratio) * image
1038
+
1039
+ # b, 1024
1040
+ # fq_t = self.FPN(vis, x)
1041
+ #
1042
+ # fv_t = self.gap(fq_t)
1043
+
1044
+ loss1 = self.IT_loss(x, text)
1045
+
1046
+ loss = loss1
1047
+
1048
+ ft = text
1049
+ fi = x
1050
+ fv = None
1051
+ elif stage == '2nd':
1052
+ '''
1053
+ img: b, 3, h, w
1054
+ word: b, words
1055
+ word_mask: b, words
1056
+ mask: b, 1, h, w
1057
+ stage: 1st or 2nd stage
1058
+ '''
1059
+ # padding mask used in decoder
1060
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1061
+
1062
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1063
+ # word: b, length, 512
1064
+ # text: b, 1024
1065
+ # image: b, 1024
1066
+ vis, image = self.backbone.encode_image(img)
1067
+
1068
+ word, text = self.backbone.encode_text(txt)
1069
+
1070
+ x = self.ADP(image)
1071
+
1072
+ x = self.ratio * x + (1 - self.ratio) * image
1073
+
1074
+ # x_t = self.trans(x)
1075
+ # fq = self.FPN(vis, x_t)
1076
+ fq_t = self.FPN(vis, x)
1077
+
1078
+ fv_t = self.gap(fq_t)
1079
+
1080
+ # b, 1024
1081
+
1082
+ loss2 = self.IT_loss(fv_t, text)
1083
+
1084
+ loss = (loss2)
1085
+ fv = fv_t
1086
+ ft = text
1087
+ fi = x
1088
+ elif stage == '3rd':
1089
+ '''
1090
+ img: b, 3, h, w
1091
+ word: b, words
1092
+ word_mask: b, words
1093
+ mask: b, 1, h, w
1094
+ stage: 1st or 2nd stage
1095
+ '''
1096
+ # padding mask used in decoder
1097
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1098
+
1099
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1100
+ # word: b, length, 512
1101
+ # text: b, 1024
1102
+ # image: b, 1024
1103
+ vis, image = self.backbone.encode_image(img)
1104
+
1105
+ text = self.backbone.encode_text(txt)
1106
+
1107
+ x = self.ADP(text)
1108
+ ratio = 0.2
1109
+ x = ratio * x + (1 - ratio) * text
1110
+
1111
+ # x_t = self.trans(x)
1112
+ # fq = self.FPN(vis, x_t)
1113
+
1114
+ # b, 1024
1115
+ loss1 = self.IT_loss(image, x)
1116
+
1117
+
1118
+ loss = loss1
1119
+ fv = None
1120
+ ft = x
1121
+ fi = image
1122
+ elif stage == '4th':
1123
+ '''
1124
+ img: b, 3, h, w
1125
+ word: b, words
1126
+ word_mask: b, words
1127
+ mask: b, 1, h, w
1128
+ stage: 1st or 2nd stage
1129
+ '''
1130
+ # padding mask used in decoder
1131
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1132
+
1133
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1134
+ # word: b, length, 512
1135
+ # text: b, 1024
1136
+ # image: b, 1024
1137
+ vis, image = self.backbone.encode_image(img)
1138
+ word, text = self.backbone.encode_text(txt)
1139
+ # x = self.ADP(image)
1140
+ # ratio = 0.2
1141
+ # x = ratio * x + (1 - ratio) * text
1142
+ fq_t = self.FPN(vis, image)
1143
+
1144
+ fv_t = self.gap(fq_t)
1145
+ ratio_1 = 0.2
1146
+ # b, 1024
1147
+ loss2 = self.IT_loss(fv_t, text)
1148
+
1149
+ loss = loss2
1150
+ fv = fv_t
1151
+ fi = None
1152
+ ft = text
1153
+ elif stage == '5th':
1154
+ '''
1155
+ img: b, 3, h, w
1156
+ word: b, words
1157
+ word_mask: b, words
1158
+ mask: b, 1, h, w
1159
+ stage: 1st or 2nd stage
1160
+ '''
1161
+ # padding mask used in decoder
1162
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1163
+
1164
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1165
+ # word: b, length, 512
1166
+ # text: b, 1024
1167
+ # image: b, 1024
1168
+ vis, image = self.backbone.encode_image(img)
1169
+ word, text = self.backbone.encode_text(txt)
1170
+ x = self.ADP(image)
1171
+ ratio = 0.2
1172
+ x = ratio * x + (1 - ratio) * image
1173
+
1174
+ y = self.ADP_t(text)
1175
+ ratio_1 = 0.2
1176
+ y = ratio * y + (1 - ratio_1) * text
1177
+
1178
+ fq_t = self.FPN(vis, image)
1179
+
1180
+ fv_t = self.gap(fq_t)
1181
+
1182
+
1183
+ # b, 1024
1184
+
1185
+ loss2 = self.IT_loss(fv_t, y)
1186
+
1187
+ loss = loss2
1188
+ fv = fv_t
1189
+ fi = x
1190
+ ft = y
1191
+
1192
+ return loss, fv, fi, ft
1193
+
1194
+ class GeoRSCLIP(nn.Module):
1195
+ def __init__(self, cfg):
1196
+ super().__init__()
1197
+ # Vision & Text Encoder & Label Encoder
1198
+ clip_model = torch.load(cfg.clip_pretrain,
1199
+ map_location="cpu")
1200
+
1201
+ backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
1202
+ self.backbone = backbone.float()
1203
+
1204
+ def forward(self, img, txt, stage):
1205
+
1206
+
1207
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1208
+
1209
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1210
+ # word: b, length, 512
1211
+ # text: b, 1024
1212
+ # image: b, 1024
1213
+ vis, image = self.backbone.encode_image(img)
1214
+
1215
+ word, text = self.backbone.encode_text(txt)
1216
+
1217
+ loss = None
1218
+
1219
+ ft = text
1220
+ fi = image
1221
+ fv = None
1222
+ return loss, fv, fi, ft
1223
+
1224
+ class CISEN(nn.Module):
1225
+ def __init__(self, cfg):
1226
+ super().__init__()
1227
+ # Vision & Text Encoder & Label Encoder
1228
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1229
+ map_location="cpu").eval()
1230
+
1231
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
1232
+ # Multi-Modal FPN
1233
+ self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1234
+ # Fined-grained Fusion
1235
+ self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
1236
+ d_model=cfg.vis_dim,
1237
+ nhead=cfg.num_head,
1238
+ dim_ffn=cfg.dim_ffn,
1239
+ dropout=cfg.dropout,
1240
+ return_intermediate=cfg.intermediate)
1241
+ # adaptively aggretation
1242
+ self.ASFF = AdaptiveSpatialFeatureFusion(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1243
+ # text projector
1244
+ self.projT = Text_Projector(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1245
+ # image projector
1246
+ # self.projI = Image_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1247
+ # parameter
1248
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1249
+ self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1250
+ self.share_temperature = True
1251
+ self.margin = 1
1252
+ self.eps = 1e-3
1253
+ self.ce = nn.CrossEntropyLoss()
1254
+ #1st stage
1255
+ self.lamda1 = cfg.lamda1
1256
+ self.lamda2 = cfg.lamda2
1257
+ self.beta1 = cfg.beta1
1258
+ self.beta2 = cfg.beta2
1259
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
1260
+ # self.fc = nn.Linear(512, cfg.num_classes)
1261
+ #2nd stage
1262
+ self.pos_samples = cfg.pos_samples
1263
+ self.neg_samples = cfg.neg_samples
1264
+
1265
+ def IT_loss(self, image_features, text_features):
1266
+ # b, 1024 / b, 1024
1267
+ batch = image_features.shape[0]
1268
+ # # normalized features
1269
+ image_features = image_features / image_features.norm(dim=-1,
1270
+ keepdim=True)
1271
+ text_features = text_features / text_features.norm(dim=-1,
1272
+ keepdim=True)
1273
+
1274
+ # cosine similarity as logits
1275
+ logit_scale = self.logit_scale.exp()
1276
+ logits_per_image = logit_scale * image_features @ text_features.t()
1277
+ logits_per_text = logits_per_image.t()
1278
+
1279
+ # shape = [global_batch_size, global_batch_size]
1280
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
1281
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
1282
+
1283
+
1284
+ return contrastive_loss
1285
+
1286
+ def IET_loss(self, image_features, text_features, pos_samples, beta):
1287
+ # b, 1024 / b, 1024
1288
+ # # normalized features
1289
+ image_features = [image_feature / image_feature.norm(dim=-1,
1290
+ keepdim=True) for image_feature in image_features]
1291
+ text_features = text_features / text_features.norm(dim=-1,
1292
+ keepdim=True)
1293
+
1294
+ # cosine similarity as logits
1295
+ logit_scale = self.logit_scale.exp()
1296
+
1297
+ # logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
1298
+ logits_per_image = [logit_scale * torch.sum(torch.mul(image_feature, text_features),1) for image_feature in image_features]
1299
+ logits_per_image = torch.stack(logits_per_image).t()
1300
+ b = logits_per_image.shape[0]
1301
+ loss1 = torch.norm(text_features - image_features[0])
1302
+ positive_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
1303
+ negative_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
1304
+ positive_tagsT[:, 0 : pos_samples + 1] = 1
1305
+ negative_tagsT[:, pos_samples + 1 : -1] = 1
1306
+
1307
+ maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
1308
+ pos_score_matT = logits_per_image * positive_tagsT
1309
+ neg_score_matT = logits_per_image * negative_tagsT
1310
+ IW_pos3T = pos_score_matT.unsqueeze(1)
1311
+ IW_neg3T = neg_score_matT.unsqueeze(-1)
1312
+ OT = 1 + IW_neg3T - IW_pos3T
1313
+ O_maskT = maskT * OT
1314
+ diffT = torch.clamp(O_maskT, 0)
1315
+ violationT = torch.sign(diffT).sum(1).sum(1)
1316
+ diffT = diffT.sum(1).sum(1)
1317
+ lossT = torch.mean(diffT / (violationT + self.eps))
1318
+ loss = beta * loss1 + lossT
1319
+
1320
+ return loss
1321
+
1322
+ def test_IET_loss(self, image_features, text_features, pos_samples, beta1, beta2):
1323
+ # text_features: enhanced_features
1324
+ # b, 1024 / b, 1024
1325
+ # # normalized features
1326
+ image_features = image_features / image_features.norm(dim=-1,
1327
+ keepdim=True)
1328
+ text_features = text_features / text_features.norm(dim=-1,
1329
+ keepdim=True)
1330
+ image_features = image_features.unsqueeze(1)
1331
+ # cosine similarity as logits
1332
+ logit_scale = self.logit_scale.exp()
1333
+ # image_features = image_features.expand(-1, text_features.shape[1], -1)
1334
+ logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
1335
+ logits_per_image = logits_per_image.squeeze(1)
1336
+ # logits_per_image = logit_scale * image_features @ text_features.t()
1337
+ # logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
1338
+
1339
+ b = logits_per_image.shape[0]
1340
+
1341
+ # loss1 = torch.norm(text_features[:, 0, :] - image_features.squeeze(1))
1342
+
1343
+ positive_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
1344
+ negative_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
1345
+ positive_tagsT[:, 0 : pos_samples + 1] = 1
1346
+ negative_tagsT[:, pos_samples + 1 : -1] = 1
1347
+
1348
+ maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
1349
+ pos_score_matT = logits_per_image * positive_tagsT
1350
+ neg_score_matT = logits_per_image * negative_tagsT
1351
+ IW_pos3T = pos_score_matT.unsqueeze(1)
1352
+ IW_neg3T = neg_score_matT.unsqueeze(-1)
1353
+ OT = 1 + IW_neg3T - IW_pos3T
1354
+ O_maskT = maskT * OT
1355
+ diffT = torch.clamp(O_maskT, 0)
1356
+ violationT = torch.sign(diffT).sum(1).sum(1)
1357
+ diffT = diffT.sum(1).sum(1)
1358
+ lossT = torch.mean(diffT / (violationT + self.eps))
1359
+ # loss = beta1 * loss1 + beta2 * lossT
1360
+ loss = lossT
1361
+ return loss
1362
+
1363
+ def test_IT_loss(self, image_features, text_features):
1364
+ # b, 1024 / b, 1024
1365
+ batch = image_features.shape[0]
1366
+ # # normalized features
1367
+ image_features = image_features / image_features.norm(dim=-1,
1368
+ keepdim=True)
1369
+ text_features = text_features / text_features.norm(dim=-1,
1370
+ keepdim=True)
1371
+ image_features = image_features.unsqueeze(1)
1372
+ # cosine similarity as logits
1373
+ logit_scale = self.logit_scale.exp()
1374
+ logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
1375
+ logits_per_image = logits_per_image.squeeze(1)
1376
+
1377
+ # shape = [global_batch_size, global_batch_size]
1378
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
1379
+ contrastive_loss = self.ce(logits_per_image, contrastive_labels)
1380
+
1381
+
1382
+ return contrastive_loss
1383
+
1384
+ def test_forward(self, img, txt):
1385
+ '''
1386
+ img: b, 3, h, w
1387
+ word: b, words
1388
+ word_mask: b, words
1389
+ mask: b, 1, h, w
1390
+ stage: 1st or 2nd stage
1391
+ '''
1392
+
1393
+ # padding mask used in decoder
1394
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1395
+
1396
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1397
+ # word: b, length, 512
1398
+ # state: b, 1024
1399
+ # image: b, 512
1400
+ vis, image = self.backbone.encode_image(img)
1401
+
1402
+ word, text = self.backbone.encode_text(txt)
1403
+
1404
+ fq = self.FPN(vis, text)
1405
+
1406
+ b, c, h, w = fq.size()
1407
+ # b, 512, 14, 14
1408
+ ff = self.FGFusion(fq, word, pad_mask)
1409
+ ff = ff.reshape(b, c, h, w)
1410
+
1411
+ f2 = self.avg(ff)
1412
+ fi = image.unsqueeze(-1).unsqueeze(-1)
1413
+ fv = self.ASFF(fi, f2)
1414
+ fi = fi.squeeze(-1).squeeze(-1)
1415
+ # b, 1024
1416
+ ft = self.projT(text)
1417
+ loss1 = self.IT_loss(fi, ft)
1418
+ loss2 = self.IT_loss(fv, ft)
1419
+ loss = self.lamda1 * loss1 + self.lamda2 * loss2
1420
+
1421
+ return loss, fv, ft, fi
1422
+
1423
+ def forward(self, img, txt, stage):
1424
+
1425
+ if stage == '1st':
1426
+ '''
1427
+ img: b, 3, h, w
1428
+ word: b, words
1429
+ word_mask: b, words
1430
+ mask: b, 1, h, w
1431
+ stage: 1st or 2nd stage
1432
+ '''
1433
+ # padding mask used in decoder
1434
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1435
+
1436
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1437
+ # word: b, length, 512
1438
+ # state: b, 1024
1439
+ # image: b, 512
1440
+ vis, image = self.backbone.encode_image(img)
1441
+
1442
+ word, text = self.backbone.encode_text(txt)
1443
+
1444
+ fq = self.FPN(vis, text)
1445
+
1446
+ b, c, h, w = fq.size()
1447
+ # b, 512, 14, 14
1448
+ ff = self.FGFusion(fq, word, pad_mask)
1449
+ ff = ff.reshape(b, c, h, w)
1450
+
1451
+ f2 = self.avg(ff)
1452
+ fi = image.unsqueeze(-1).unsqueeze(-1)
1453
+ fv = self.ASFF(fi, f2)
1454
+ fi = fi.squeeze(-1).squeeze(-1)
1455
+ # b, 1024
1456
+ ft = self.projT(text)
1457
+ loss1 = self.IT_loss(fi, ft)
1458
+ loss2 = self.IT_loss(fv, ft)
1459
+ loss = self.lamda1 * loss1 + self.lamda2 * loss2
1460
+
1461
+ elif stage == '2nd':
1462
+ """
1463
+ txt: b, num, words
1464
+ img: b, 3, h, w
1465
+ """
1466
+
1467
+ # txt = b * num, word
1468
+ b, num, l = txt.shape[0], txt.shape[1], txt.shape[2]
1469
+ txt = txt.view(-1, txt.size(-1))
1470
+
1471
+ # padding mask used in decoder
1472
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1473
+
1474
+ b = img.shape[0]
1475
+ vis, image = self.backbone.encode_image(img)
1476
+ word, text = self.backbone.encode_text(txt)
1477
+
1478
+ fq = self.FPN(vis, text)
1479
+ # b, 512, 14, 14 (C4)
1480
+
1481
+ b, c, h, w = fq.size()
1482
+ # b, 512, 14, 14
1483
+ ff = self.FGFusion(fq, word, pad_mask)
1484
+ ff = ff.reshape(b, c, h, w)
1485
+
1486
+ f2 = self.avg(ff)
1487
+ fi = image.unsqueeze(-1).unsqueeze(-1)
1488
+ fi_ = fi.repeat(int(f2.shape[0] / fi.shape[0]), 1, 1, 1)
1489
+
1490
+ fv = self.ASFF(fi_, f2)
1491
+ fi = fi.squeeze(-1).squeeze(-1)
1492
+ # fi_ = fi_.squeeze(-1).squeeze(-1)
1493
+ # b, 1024
1494
+ ft = text.view(img.shape[0], int(text.shape[0] / img.shape[0]), -1)[:, 0, :]
1495
+ fv = fv.view(ft.shape[0], int(text.shape[0] / ft.shape[0]), fv.shape[1])
1496
+ loss = self.test_IET_loss(fi, fv, self.pos_samples, self.beta1, self.beta2)
1497
+
1498
+
1499
+ elif stage == 'test':
1500
+ """
1501
+ txt: b, num, words
1502
+ img: b, 3, h, w
1503
+ """
1504
+ txt = txt.permute(1, 0, 2)
1505
+
1506
+ # txt = b * num, word
1507
+ # txt = txt.view(-1, txt.size(-1))
1508
+
1509
+ # padding mask used in decoder
1510
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1511
+
1512
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1513
+ # word: b, length, 512
1514
+ # state: b, 1024
1515
+ # image: b, 512
1516
+ b = img.shape[0]
1517
+ words = []
1518
+ texts = []
1519
+ vis, image = self.backbone.encode_image(img)
1520
+ for i in range(txt.shape[0]):
1521
+ word, text = self.backbone.encode_text(txt[i])
1522
+ words.append(word)
1523
+ texts.append(text)
1524
+
1525
+ fvn = []
1526
+ # b, 512, 14, 14 (C4)
1527
+ for i in range(txt.shape[0]):
1528
+ fq = self.FPN(vis, texts[i])
1529
+
1530
+ b, c, h, w = fq.size()
1531
+ # b, 512, 14, 14
1532
+ ff = self.FGFusion(fq, words[i], pad_mask[i, :, :])
1533
+ ff = ff.reshape(b, c, h, w)
1534
+
1535
+ f2 = self.avg(ff)
1536
+ fi = image.unsqueeze(-1).unsqueeze(-1)
1537
+ fv = self.ASFF(fi, f2)
1538
+ fi = fi.squeeze(-1).squeeze(-1)
1539
+ fvn.append(fv)
1540
+
1541
+ # b, 1024
1542
+ ft = self.projT(texts[0])
1543
+ loss = self.IET_loss(fvn, ft, self.pos_samples, self.beta)
1544
+ fv = fvn
1545
+
1546
+
1547
+ else:
1548
+ print('stage should be either 1st or 2nd or test')
1549
+
1550
+
1551
+
1552
+ # labels = torch.ones(image.shape[0], image.shape[0]).to(image.device)
1553
+ # labels[:,-1] = 0
1554
+ # labels[3, :] = 0
1555
+
1556
+
1557
+ # out = self.avg(fq)
1558
+ # out = out.squeeze(-1).squeeze(-1)
1559
+ # out = self.fc(out)
1560
+
1561
+ return loss, fv, fi, ft
1562
+
1563
+
1564
+
1565
+ class CRIS(nn.Module):
1566
+ def __init__(self, cfg):
1567
+ super().__init__()
1568
+ # Vision & Text Encoder & Label Encoder
1569
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1570
+ map_location="cpu").eval()
1571
+
1572
+ self.backbone, _, _, _, _ = build_model(clip_model.state_dict(), cfg.word_len)
1573
+ self.backbone = self.backbone.float()
1574
+ self.Label_encoder = build_promptlearner(clip_model.state_dict()).float()
1575
+ self.Label_encoder.init_label_emb(cfg.label_path)
1576
+
1577
+ # Multi-Modal FPN
1578
+ self.FPN = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1579
+ # Fined-grained Fusion
1580
+ self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
1581
+ d_model=cfg.vis_dim,
1582
+ nhead=cfg.num_head,
1583
+ dim_ffn=cfg.dim_ffn,
1584
+ dropout=cfg.dropout,
1585
+ return_intermediate=cfg.intermediate)
1586
+ # adaptively aggretation
1587
+ self.ASFF = AdaptiveSpatialFeatureFusion(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1588
+ # text projector
1589
+ self.projT = Text_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1590
+ # parameter
1591
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1592
+ self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1593
+ self.share_temperature = True
1594
+ self.margin = 1
1595
+ self.eps = 1e-3
1596
+ self.ce = nn.CrossEntropyLoss()
1597
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
1598
+ self.fc = nn.Linear(512, cfg.num_classes)
1599
+
1600
+
1601
+
1602
+ def IT_loss(self, image_features, text_features):
1603
+ # b, 1024 / b, 1024
1604
+ batch = image_features.shape[0]
1605
+ # # normalized features
1606
+ image_features = image_features / image_features.norm(dim=-1,
1607
+ keepdim=True)
1608
+ text_features = text_features / text_features.norm(dim=-1,
1609
+ keepdim=True)
1610
+
1611
+ # cosine similarity as logits
1612
+ logit_scale = self.logit_scale.exp()
1613
+ logits_per_image = logit_scale * image_features @ text_features.t()
1614
+ logits_per_text = logits_per_image.t()
1615
+
1616
+ # shape = [global_batch_size, global_batch_size]
1617
+ contrastive_labels = torch.arange(batch).to(logits_per_image.device)
1618
+ contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
1619
+
1620
+
1621
+ return contrastive_loss
1622
+
1623
+ def IL_loss(self, image_features, label_features, labels):
1624
+
1625
+ # b, 1024 / K, 1024/ b, K
1626
+ positive_tagsT = torch.clamp(labels,0.,1.)
1627
+ negative_tagsT = torch.clamp(-labels,0.,1.)
1628
+ maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
1629
+
1630
+ # normalized features
1631
+
1632
+ image_features = image_features / image_features.norm(dim=-1,
1633
+ keepdim=True)
1634
+ label_features = label_features / label_features.norm(dim=-1,
1635
+ keepdim=True)
1636
+ # cosine similarity as logits
1637
+ logit_scale = self.multi_label_logit_scale.exp()
1638
+ logits_per_image = logit_scale * image_features @ label_features.t()
1639
+ # logits_per_label = logit_scale * label_features @ image_features.t()
1640
+ pos_score_matT = logits_per_image * positive_tagsT
1641
+ neg_score_matT = logits_per_image * negative_tagsT
1642
+ IW_pos3T = pos_score_matT.unsqueeze(1)
1643
+ IW_neg3T = neg_score_matT.unsqueeze(-1)
1644
+ OT = self.margin + IW_neg3T - IW_pos3T
1645
+ O_maskT = maskT * OT
1646
+ diffT = torch.clamp(O_maskT, 0)
1647
+ violationT = torch.sign(diffT).sum(1).sum(1)
1648
+ diffT = diffT.sum(1).sum(1)
1649
+ lossT = torch.mean(diffT / (violationT + self.eps))
1650
+
1651
+
1652
+
1653
+
1654
+ return lossT
1655
+
1656
+ def margin_loss(self, image_features, label_features, labels):
1657
+
1658
+ # b, 1024 / K, 1024/ b, K
1659
+
1660
+
1661
+ # normalized features
1662
+
1663
+ image_features = image_features / image_features.norm(dim=-1,
1664
+ keepdim=True)
1665
+ label_features = label_features / label_features.norm(dim=-1,
1666
+ keepdim=True)
1667
+ # cosine similarity as logits
1668
+ logit_scale = self.multi_label_logit_scale.exp()
1669
+ logits_per_image = logit_scale * image_features @ label_features.t()
1670
+ # logits_per_label = logit_scale * label_features @ image_features.t()
1671
+
1672
+ image_label_positive_pairs = logits_per_image * labels
1673
+ image_label_mean_positive = image_label_positive_pairs.sum() / labels.sum()
1674
+ image_label_negative_pairs = logits_per_image * (1 - labels)
1675
+ image_label_mean_negative = image_label_negative_pairs.sum() / (logits_per_image.numel() - labels.sum() + self.eps)
1676
+
1677
+ contrastive_loss = torch.relu(self.margin - image_label_mean_positive + image_label_mean_negative)
1678
+
1679
+ return contrastive_loss
1680
+
1681
+ def forward(self, img, txt, target=None):
1682
+ '''
1683
+ img: b, 3, h, w
1684
+ word: b, words
1685
+ word_mask: b, words
1686
+ mask: b, 1, h, w
1687
+ '''
1688
+
1689
+ # padding mask used in decoder
1690
+
1691
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
1692
+
1693
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
1694
+ # word: b, length, 512
1695
+ # state: b, 1024
1696
+ # image: b, 512
1697
+ vis, image = self.backbone.encode_image(img)
1698
+ word, text = self.backbone.encode_text(txt)
1699
+
1700
+
1701
+ fl = self.Label_encoder(image.device)
1702
+ # b, 512, 14, 14 (C4)
1703
+ fq = self.FPN(vis, text)
1704
+ b, c, h, w = fq.size()
1705
+ # b, 512, 14, 14
1706
+ ff = self.FGFusion(fq, word, pad_mask)
1707
+ # b, 512, 196
1708
+ ff = ff.reshape(b, c, h, w)
1709
+ f2 = self.avg(ff)
1710
+ # b, 1024
1711
+ f1 = image.unsqueeze(-1).unsqueeze(-1)
1712
+ fv = self.ASFF(f1, f2)
1713
+
1714
+ # b, 1024
1715
+ ft = self.projT(text)
1716
+ # labels = torch.ones(image.shape[0], image.shape[0]).to(image.device)
1717
+ # labels[:,-1] = 0
1718
+ # labels[3, :] = 0
1719
+
1720
+ loss1 = self.IT_loss(fv, ft)
1721
+ loss2 = self.IL_loss(fv, fl, target)
1722
+ loss = loss1 + loss2
1723
+ # out = self.avg(fq)
1724
+ # out = out.squeeze(-1).squeeze(-1)
1725
+ # out = self.fc(out)
1726
+
1727
+ return loss, fv, ft, fl
1728
+
1729
+ class zh_clip(nn.Module):
1730
+ def __init__(self, cfg):
1731
+ super().__init__()
1732
+ # Vision & Text Encoder
1733
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1734
+ map_location="cpu").eval()
1735
+ self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
1736
+
1737
+ self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
1738
+ self.text_lin = nn.Linear(512, 1024)
1739
+
1740
+
1741
+ # Multi-Modal FPN
1742
+ self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1743
+ # Decoder
1744
+
1745
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
1746
+ self.fc = nn.Linear(512, cfg.num_classes)
1747
+ def forward(self, img, word):
1748
+ '''
1749
+ img: b, 3, h, w
1750
+ word: b, words
1751
+ word_mask: b, words
1752
+ mask: b, 1, h, w
1753
+ '''
1754
+ # padding mask used in decoder
1755
+
1756
+
1757
+ # vis: v1 / v2 / b, 49, 1024/ b, 196, 512
1758
+ # state: b, 1024
1759
+ # feat: f1 / f2 / b, 1024, 7, 7/ b, 1024, 7, 7
1760
+ # cls: c1 / c2 / b, 1024/ b, 512
1761
+ vis, feat, cls = self.backbone.encode_image(img)
1762
+ state = self.text_encoder(word.squeeze(1)).logits
1763
+ state = self.text_lin(state)
1764
+ # b, 1024, 7, 7 (C5)
1765
+ fq = self.neck(feat, state)
1766
+
1767
+ out = self.avg(fq)
1768
+ out = out.squeeze(-1).squeeze(-1)
1769
+ out = self.fc(out)
1770
+
1771
+ return out
1772
+
1773
+ class poi_clip(nn.Module):
1774
+ def __init__(self, cfg):
1775
+ super().__init__()
1776
+ # Vision & Text Encoder
1777
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1778
+ map_location="cpu").eval()
1779
+ self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
1780
+
1781
+ self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
1782
+ self.text_lin = nn.Linear(512, 1024)
1783
+
1784
+
1785
+ # Multi-Modal FPN
1786
+ self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1787
+ # Decoder
1788
+
1789
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
1790
+ self.fc = nn.Linear(512, cfg.num_classes)
1791
+ def forward(self, img, word):
1792
+ '''
1793
+ img: b, 3, h, w
1794
+ word: b, words
1795
+ word_mask: b, words
1796
+ mask: b, 1, h, w
1797
+ '''
1798
+ # padding mask used in decoder
1799
+
1800
+
1801
+ # vis: v1 / v2 / b, 49, 1024/ b, 196, 512
1802
+ # state: b, 1024
1803
+ # feat: f1 / f2 / b, 1024, 7, 7/ b, 1024, 7, 7
1804
+ # cls: c1 / c2 / b, 1024/ b, 512
1805
+ vis, feat, cls = self.backbone.encode_image(img)
1806
+ state = self.text_encoder(word.squeeze(1)).logits
1807
+ state = self.text_lin(state)
1808
+ # b, 1024, 7, 7 (C5)
1809
+ fq = self.neck(feat, state)
1810
+
1811
+ out = self.avg(fq)
1812
+ out = out.squeeze(-1).squeeze(-1)
1813
+ out = self.fc(out)
1814
+
1815
+ return out
1816
+
1817
+ class Clip_hash_model(nn.Module):
1818
+ def __init__(self, cfg):
1819
+ super().__init__()
1820
+
1821
+ # Vision & Text Encoder
1822
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1823
+ map_location="cpu").eval()
1824
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
1825
+ # Multi-Modal FPN
1826
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1827
+
1828
+ # Decoder
1829
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
1830
+
1831
+ self.classifier = nn.Sequential(
1832
+ nn.Linear(cfg.fpn_out[1], cfg.hash_dim, bias=True),
1833
+ nn.Tanh(),
1834
+ )
1835
+
1836
+ self.classifier2 = nn.Sequential(
1837
+ nn.Linear(cfg.hash_dim, cfg.num_classes)
1838
+ )
1839
+
1840
+ # Hash Module
1841
+ self.image_module = nn.Sequential(
1842
+ nn.Linear(cfg.img_dim, cfg.hidden_dim, bias=True),
1843
+ nn.BatchNorm1d(cfg.hidden_dim),
1844
+ nn.ReLU(True),
1845
+ nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
1846
+ nn.Tanh()
1847
+ )
1848
+
1849
+ self.text_module = nn.Sequential(
1850
+ nn.Linear(cfg.txt_dim, cfg.hidden_dim, bias=True),
1851
+ nn.BatchNorm1d(cfg.hidden_dim),
1852
+ nn.ReLU(True),
1853
+ nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
1854
+ nn.Tanh()
1855
+ )
1856
+ def forward(self, img, word, mask=None):
1857
+ '''
1858
+ img: b, 3, h, w
1859
+ word: b, words
1860
+ word_mask: b, words
1861
+ '''
1862
+ pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
1863
+ # vis: C3 / C4 / C5
1864
+ # word: b, length, 512
1865
+ # state: b, 1024
1866
+ vis, image = self.backbone.encode_image(img)
1867
+ word, state = self.backbone.encode_text(word)
1868
+
1869
+ # b, 512, 26, 26 (C4)
1870
+ fq = self.neck(vis, state)
1871
+
1872
+ # out_hash: b, code_length
1873
+ # res: b, classes
1874
+ out = self.avg(fq)
1875
+ out = out.squeeze(-1).squeeze(-1)
1876
+ out_hash = self.classifier(out)
1877
+ res = self.classifier2(out_hash)
1878
+
1879
+ # img_hash: b, code_length
1880
+ # txt_hash: b, code_length
1881
+ img_hash = self.image_module(image)
1882
+ txt_hash = self.text_module(state)
1883
+
1884
+
1885
+
1886
+ return img_hash, txt_hash, out_hash, res
1887
+
1888
+ class Clip_model(nn.Module):
1889
+ def __init__(self, cfg):
1890
+ super().__init__()
1891
+
1892
+ # Vision & Text Encoder
1893
+ clip_model = torch.jit.load(cfg.clip_pretrain,
1894
+ map_location="cpu").eval()
1895
+
1896
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
1897
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
1898
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
1899
+
1900
+ def forward(self, img, word, mask=None):
1901
+ '''
1902
+ img: b, 3, h, w
1903
+ word: b, words
1904
+ word_mask: b, words
1905
+ '''
1906
+ # vis: C3 / C4 / C5
1907
+ # word: b, length, 512
1908
+ # state: b, 1024
1909
+ pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
1910
+ vis, image = self.backbone.encode_image(img)
1911
+ word, state = self.backbone.encode_text(word)
1912
+ f = self.neck(vis, state)
1913
+ out = self.avg(f)
1914
+ out = out.squeeze(-1).squeeze(-1)
1915
+ image_features = image / image.norm(dim=-1, keepdim=True)
1916
+ text_features = state / state.norm(dim=-1, keepdim=True)
1917
+
1918
+ # cosine similarity as logits
1919
+ logit_scale = self.backbone.logit_scale.exp()
1920
+ logits_per_image = logit_scale * image_features @ text_features.t()
1921
+ logits_per_text = logits_per_image.t()
1922
+
1923
+ # shape = [global_batch_size, global_batch_size]
1924
+ return logits_per_image, logits_per_text
1925
+
1926
+
1927
+ class CISEN_rsvit_hug(nn.Module, PyTorchModelHubMixin):
1928
+ def __init__(self, embed_dim, image_resolution, vision_layers, vision_width,
1929
+ vision_patch_size, context_length, txt_length, vocab_size,
1930
+ transformer_width, transformer_heads, transformer_layers, patch_size,
1931
+ output_dim, ratio, emb_dim, fpn_in, fpn_out):
1932
+ super().__init__()
1933
+ # Vision & Text Encoder & Label Encoder
1934
+ vision_heads = vision_width * 32 // 64
1935
+
1936
+ backbone = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
1937
+ vision_patch_size, context_length, txt_length, vocab_size,
1938
+ transformer_width, transformer_heads, transformer_layers)
1939
+ self.backbone = backbone.float()
1940
+ self.patch_emb = image_resolution // patch_size
1941
+
1942
+ self.FPN = ViTFPN(image_resolution, in_channels=fpn_in, out_channels=fpn_out)
1943
+
1944
+ self.ADP = Adapter(output_dim, 4)
1945
+ # parameter
1946
+ self.ratio = ratio
1947
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1948
+ self.share_temperature = True
1949
+ self.ce = nn.CrossEntropyLoss()
1950
+ self.ms_adaptor = nn.ModuleList(
1951
+ [
1952
+ nn.Sequential(
1953
+ nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
1954
+ nn.GroupNorm(32, emb_dim),
1955
+ nn.GELU(),
1956
+ nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
1957
+ ),
1958
+ nn.Sequential(
1959
+ nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
1960
+ ),
1961
+ nn.Sequential(
1962
+ nn.Identity(),
1963
+ ),
1964
+ nn.Sequential(
1965
+ nn.MaxPool2d(2),
1966
+ ),
1967
+
1968
+ ]
1969
+ )
1970
+
1971
+ self.ms_adaptor.apply(self.init_adaptor)
1972
+ def init_adaptor(self, m):
1973
+ if isinstance(m, nn.Conv2d):
1974
+ lecun_normal_(m.weight)
1975
+ if m.bias is not None:
1976
+ nn.init.constant_(m.bias, 0)
1977
+ elif isinstance(m, nn.GroupNorm):
1978
+ nn.init.constant_(m.bias, 0)
1979
+ nn.init.constant_(m.weight, 1.0)
1980
+ elif isinstance(m, nn.ConvTranspose2d):
1981
+ lecun_normal_(m.weight)
1982
+ if m.bias is not None:
1983
+ nn.init.zeros_(m.bias)
1984
+ # self.fc = nn.Linear(512, cfg.num_classes)
1985
+
1986
+ def image_encode(self, img):
1987
+ vis, image = self.backbone.encode_image(img)
1988
+
1989
+ x = self.ADP(image)
1990
+
1991
+ x = self.ratio * x + (1 - self.ratio) * image
1992
+ return x
1993
+
1994
+ def text_encode(self, txt):
1995
+
1996
+ word, text = self.backbone.encode_text(txt)
1997
+
1998
+ return text
1999
+
2000
+ def forward(self, img, txt):
2001
+ '''
2002
+ img: b, 3, h, w
2003
+ word: b, words
2004
+ word_mask: b, words
2005
+ mask: b, 1, h, w
2006
+ stage: 1st or 2nd stage
2007
+ '''
2008
+ # padding mask used in decoder
2009
+ pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
2010
+
2011
+ # vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7
2012
+ # word: b, length, 512
2013
+ # text: b, 1024
2014
+ # image: b, 1024
2015
+ vis, image = self.backbone.encode_image(img)
2016
+
2017
+ word, text = self.backbone.encode_text(txt)
2018
+
2019
+ x = self.ADP(image)
2020
+
2021
+ x = self.ratio * x + (1 - self.ratio) * image
2022
+ # Construct multi-scale feats
2023
+ vis_trans = []
2024
+ for i in range(len(self.ms_adaptor)):
2025
+ x_ = rearrange(
2026
+ vis[i],
2027
+ "b (h w) c -> b c h w",
2028
+ h=self.patch_emb,
2029
+ w=self.patch_emb,
2030
+ ).contiguous()
2031
+
2032
+ feats = self.ms_adaptor[i](x_)
2033
+
2034
+ vis_trans.append(feats)
2035
+
2036
+ # fq = self.FPN(vis, x_t)
2037
+ fv_t = self.FPN(vis_trans[1:], x, False)
2038
+ # fv_t = self.gap(fq_t)
2039
+
2040
+ # b, 1024
2041
+ fv = fv_t
2042
+ ft = text
2043
+ fi = x
2044
+
2045
+ return fv, fi, ft
cisen/utils/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.38 kB). View file
 
cisen/utils/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (12.9 kB). View file
 
cisen/utils/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
cisen/utils/config.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------------------------------------------------------
2
+ # Functions for parsing args
3
+ # -----------------------------------------------------------------------------
4
+ import copy
5
+ import os
6
+ from ast import literal_eval
7
+
8
+ import yaml
9
+
10
+
11
+ class CfgNode(dict):
12
+ """
13
+ CfgNode represents an internal node in the configuration tree. It's a simple
14
+ dict-like container that allows for attribute-based access to keys.
15
+ """
16
+ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
17
+ # Recursively convert nested dictionaries in init_dict into CfgNodes
18
+ init_dict = {} if init_dict is None else init_dict
19
+ key_list = [] if key_list is None else key_list
20
+ for k, v in init_dict.items():
21
+ if type(v) is dict:
22
+ # Convert dict to CfgNode
23
+ init_dict[k] = CfgNode(v, key_list=key_list + [k])
24
+ super(CfgNode, self).__init__(init_dict)
25
+
26
+ def __getattr__(self, name):
27
+ if name in self:
28
+ return self[name]
29
+ else:
30
+ raise AttributeError(name)
31
+
32
+ def __setattr__(self, name, value):
33
+ self[name] = value
34
+
35
+ def __str__(self):
36
+ def _indent(s_, num_spaces):
37
+ s = s_.split("\n")
38
+ if len(s) == 1:
39
+ return s_
40
+ first = s.pop(0)
41
+ s = [(num_spaces * " ") + line for line in s]
42
+ s = "\n".join(s)
43
+ s = first + "\n" + s
44
+ return s
45
+
46
+ r = ""
47
+ s = []
48
+ for k, v in sorted(self.items()):
49
+ seperator = "\n" if isinstance(v, CfgNode) else " "
50
+ attr_str = "{}:{}{}".format(str(k), seperator, str(v))
51
+ attr_str = _indent(attr_str, 2)
52
+ s.append(attr_str)
53
+ r += "\n".join(s)
54
+ return r
55
+
56
+ def __repr__(self):
57
+ return "{}({})".format(self.__class__.__name__,
58
+ super(CfgNode, self).__repr__())
59
+
60
+
61
+ def load_cfg_from_cfg_file(file):
62
+ cfg = {}
63
+ assert os.path.isfile(file) and file.endswith('.yaml'), \
64
+ '{} is not a yaml file'.format(file)
65
+
66
+ with open(file, 'r') as f:
67
+ cfg_from_file = yaml.safe_load(f)
68
+
69
+ for key in cfg_from_file:
70
+ for k, v in cfg_from_file[key].items():
71
+ cfg[k] = v
72
+
73
+ cfg = CfgNode(cfg)
74
+ return cfg
75
+
76
+
77
+ def merge_cfg_from_list(cfg, cfg_list):
78
+ new_cfg = copy.deepcopy(cfg)
79
+ assert len(cfg_list) % 2 == 0
80
+ for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
81
+ subkey = full_key.split('.')[-1]
82
+ assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
83
+ value = _decode_cfg_value(v)
84
+ value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
85
+ full_key)
86
+ setattr(new_cfg, subkey, value)
87
+
88
+ return new_cfg
89
+
90
+
91
+ def _decode_cfg_value(v):
92
+ """Decodes a raw config value (e.g., from a yaml config files or command
93
+ line argument) into a Python object.
94
+ """
95
+ # All remaining processing is only applied to strings
96
+ if not isinstance(v, str):
97
+ return v
98
+ # Try to interpret `v` as a:
99
+ # string, number, tuple, list, dict, boolean, or None
100
+ try:
101
+ v = literal_eval(v)
102
+ # The following two excepts allow v to pass through when it represents a
103
+ # string.
104
+ #
105
+ # Longer explanation:
106
+ # The type of v is always a string (before calling literal_eval), but
107
+ # sometimes it *represents* a string and other times a data structure, like
108
+ # a list. In the case that v represents a string, what we got back from the
109
+ # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
110
+ # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
111
+ # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
112
+ # will raise a SyntaxError.
113
+ except ValueError:
114
+ pass
115
+ except SyntaxError:
116
+ pass
117
+ return v
118
+
119
+
120
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
121
+ """Checks that `replacement`, which is intended to replace `original` is of
122
+ the right type. The type is correct if it matches exactly or is one of a few
123
+ cases in which the type can be easily coerced.
124
+ """
125
+ original_type = type(original)
126
+ replacement_type = type(replacement)
127
+
128
+ # The types must match (with some exceptions)
129
+ if replacement_type == original_type:
130
+ return replacement
131
+
132
+ # Cast replacement from from_type to to_type if the replacement and original
133
+ # types match from_type and to_type
134
+ def conditional_cast(from_type, to_type):
135
+ if replacement_type == from_type and original_type == to_type:
136
+ return True, to_type(replacement)
137
+ else:
138
+ return False, None
139
+
140
+ # Conditionally casts
141
+ # list <-> tuple
142
+ casts = [(tuple, list), (list, tuple)]
143
+ # For py2: allow converting from str (bytes) to a unicode string
144
+ try:
145
+ casts.append((str, unicode)) # noqa: F821
146
+ except Exception:
147
+ pass
148
+
149
+ for (from_type, to_type) in casts:
150
+ converted, converted_value = conditional_cast(from_type, to_type)
151
+ if converted:
152
+ return converted_value
153
+
154
+ raise ValueError(
155
+ "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
156
+ "key: {}".format(original_type, replacement_type, original,
157
+ replacement, full_key))
cisen/utils/dataset.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union
3
+ import random
4
+
5
+ import json
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from torchvision import transforms
11
+ from loguru import logger
12
+
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
+ _tokenizer = _Tokenizer()
16
+
17
+ # text_tokenize = AutoTokenizer.from_pretrained("./Taiyi-CLIP-s", model_max_length=512)
18
+ def tokenize(texts: Union[str, List[str]],
19
+ context_length: int = 77,
20
+ truncate: bool = False) -> torch.LongTensor:
21
+ """
22
+ Returns the tokenized representation of given input string(s)
23
+
24
+ Parameters
25
+ ----------
26
+ texts : Union[str, List[str]]
27
+ An input string or a list of input strings to tokenize
28
+
29
+ context_length : int
30
+ The context length to use; all CLIP models use 77 as the context length
31
+
32
+ truncate: bool
33
+ Whether to truncate the text in case its encoding is longer than the context length
34
+
35
+ Returns
36
+ -------
37
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
38
+ """
39
+ if isinstance(texts, str):
40
+ texts = [texts]
41
+
42
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
43
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
44
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
45
+ for text in texts]
46
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
47
+
48
+ for i, tokens in enumerate(all_tokens):
49
+ if len(tokens) > context_length:
50
+ if truncate:
51
+ tokens = tokens[:context_length]
52
+ tokens[-1] = eot_token
53
+ else:
54
+ raise RuntimeError(
55
+ f"Input {texts[i]} is too long for context length {context_length}"
56
+ )
57
+ result[i, :len(tokens)] = torch.tensor(tokens)
58
+
59
+ return result
60
+
61
+ def select_idxs(seq_length, n_to_select, n_from_select, seed=42):
62
+ """
63
+ Select n_to_select indexes from each consequent n_from_select indexes from range with length seq_length, split
64
+ selected indexes to separate arrays
65
+
66
+ Example:
67
+
68
+ seq_length = 20
69
+ n_from_select = 5
70
+ n_to_select = 2
71
+
72
+ input, range of length seq_length:
73
+ range = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
74
+
75
+ sequences of length n_from_select:
76
+ sequences = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]
77
+
78
+ selected n_to_select elements from each sequence
79
+ selected = [[0, 4], [7, 9], [13, 14], [16, 18]]
80
+
81
+ output, n_to_select lists of length seq_length / n_from_select:
82
+ output = [[0, 7, 13, 16], [4, 9, 14, 18]]
83
+
84
+ :param seq_length: length of sequence, say 10
85
+ :param n_to_select: number of elements to select
86
+ :param n_from_select: number of consequent elements
87
+ :return:
88
+ """
89
+ random.seed(seed)
90
+ idxs = [[] for _ in range(n_to_select)]
91
+ for i in range(seq_length // n_from_select):
92
+ ints = random.sample(range(n_from_select), n_to_select)
93
+ for j in range(n_to_select):
94
+ idxs[j].append(i * n_from_select + ints[j])
95
+ return idxs
96
+
97
+ def read_json(file_name, suppress_console_info=False):
98
+ """
99
+ Read JSON
100
+
101
+ :param file_name: input JSON path
102
+ :param suppress_console_info: toggle console printing
103
+ :return: dictionary from JSON
104
+ """
105
+ with open(file_name, 'r') as f:
106
+ data = json.load(f)
107
+ if not suppress_console_info:
108
+ print("Read from:", file_name)
109
+ return data
110
+
111
+ def get_image_file_names(data, suppress_console_info=False):# ok
112
+ """
113
+ Get list of image file names
114
+
115
+ :param data: original data from JSON
116
+ :param suppress_console_info: toggle console printing
117
+ :return: list of strings (file names)
118
+ """
119
+
120
+ file_names = []
121
+ for img in data['images']:
122
+ image_name = img["image_name"]
123
+ sample_id = img["sample_id"]
124
+ path_data = f'{sample_id}/{image_name}'
125
+ file_names.append(path_data)
126
+ if not suppress_console_info:
127
+ print("Total number of files:", len(file_names))
128
+ return file_names
129
+
130
+ def get_images(file_names, args):
131
+ transform = transforms.Compose([
132
+ transforms.Resize(224),
133
+ transforms.CenterCrop(224),
134
+ transforms.ToTensor(),
135
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
136
+ ])
137
+ imgs = []
138
+ for i in range(len(file_names)):
139
+
140
+ img = np.array(transform(Image.open(os.path.join(args.imgs_folder, file_names[i]))))
141
+ imgs.append(img)
142
+
143
+ return np.array(imgs)
144
+
145
+ def get_captions(data, suppress_console_info=False):
146
+ """
147
+ Get list of formatted captions
148
+ :param data: original data from JSON
149
+ :return: list of strings (captions)
150
+ """
151
+ def format_caption(string):
152
+ return string.replace('.', '').replace(',', '').replace('!', '').replace('?', '').lower()
153
+
154
+ captions = []
155
+ augmented_captions_rb = []
156
+ augmented_captions_bt_prob = []
157
+ augmented_captions_bt_chain = []
158
+ for img in data['images']:
159
+ for sent in img['sentences']:
160
+ captions.append(format_caption(sent['raw']))
161
+ try:
162
+ augmented_captions_rb.append(format_caption(sent['aug_rb']))
163
+ except:
164
+ pass
165
+ try:
166
+ augmented_captions_bt_prob.append(format_caption(sent['aug_bt_prob']))
167
+ except:
168
+ pass
169
+ try:
170
+ augmented_captions_bt_chain.append(format_caption(sent['aug_bt_chain']))
171
+ except:
172
+ pass
173
+ if not suppress_console_info:
174
+ logger.info("Total number of captions:{}", len(captions))
175
+ logger.info("Total number of augmented captions RB:{}", len(augmented_captions_rb))
176
+ logger.info("Total number of augmented captions BT (prob):{}", len(augmented_captions_bt_prob))
177
+ logger.info("Total number of augmented captions BT (chain):{}", len(augmented_captions_bt_chain))
178
+ return captions, augmented_captions_rb, augmented_captions_bt_prob, augmented_captions_bt_chain
179
+
180
+ def get_labels(data, suppress_console_info=False):
181
+ """
182
+ Get list of labels
183
+
184
+ :param data: original data from JSON
185
+ :param suppress_console_info: toggle console printing
186
+ :return: list ints (labels)
187
+ """
188
+
189
+ labels = []
190
+ for img in data['images']:
191
+ labels.append(img["classcode"])
192
+ if not suppress_console_info:
193
+ print("Total number of labels:", len(labels))
194
+ return labels
195
+
196
+ def remove_tokens(data):
197
+ """
198
+ Removes 'tokens' key from caption record, if exists; halves the size of the file
199
+
200
+ :param data: original data
201
+ :return: data without tokens
202
+ """
203
+ for img in data['images']:
204
+ for sent in img['sentences']:
205
+ try:
206
+ sent.pop("tokens")
207
+ except:
208
+ pass
209
+ return data
210
+
211
+ def write_json(file_name, data):
212
+ """
213
+ Write dictionary to JSON file
214
+
215
+ :param file_name: output path
216
+ :param data: dictionary
217
+ :return: None
218
+ """
219
+ bn = os.path.basename(file_name)
220
+ dn = os.path.dirname(file_name)
221
+ name, ext = os.path.splitext(bn)
222
+ file_name = os.path.join(dn, name + '.json')
223
+ with open(file_name, 'w') as f:
224
+ f.write(json.dumps(data, indent='\t'))
225
+ print("Written to:", file_name)
226
+
227
+ def get_split_idxs(arr_len, args):
228
+ """
229
+ Get indexes for training, query and db subsets
230
+
231
+ :param: arr_len: array length
232
+
233
+ :return: indexes for training, query and db subsets
234
+ """
235
+ idx_all = list(range(arr_len))
236
+ idx_train, idx_eval = split_indexes(idx_all, args.dataset_train_split)
237
+ idx_query, idx_db = split_indexes(idx_eval, args.dataset_query_split)
238
+
239
+ return idx_train, idx_eval, idx_query, idx_db
240
+
241
+ def split_indexes(idx_all, split):
242
+ """
243
+ Splits list in two parts.
244
+
245
+ :param idx_all: array to split
246
+ :param split: portion to split
247
+ :return: splitted lists
248
+ """
249
+ idx_length = len(idx_all)
250
+ selection_length = int(idx_length * split)
251
+
252
+ idx_selection = sorted(random.sample(idx_all, selection_length))
253
+
254
+ idx_rest = sorted(list(set(idx_all).difference(set(idx_selection))))
255
+
256
+ return idx_selection, idx_rest
257
+
258
+ def get_caption_idxs(idx_train, idx_query, idx_db):
259
+ """
260
+ Get caption indexes.
261
+
262
+ :param: idx_train: train image (and label) indexes
263
+ :param: idx_query: query image (and label) indexes
264
+ :param: idx_db: db image (and label) indexes
265
+
266
+ :return: caption indexes for corresponding index sets
267
+ """
268
+ idx_train_cap = get_caption_idxs_from_img_idxs(idx_train, num=5)
269
+ idx_query_cap = get_caption_idxs_from_img_idxs(idx_query, num=5)
270
+ idx_db_cap = get_caption_idxs_from_img_idxs(idx_db)
271
+ return idx_train_cap, idx_query_cap, idx_db_cap
272
+
273
+ def get_caption_idxs_from_img_idxs(img_idxs, num=5):
274
+ """
275
+ Get caption indexes. There are 5 captions for each image (and label).
276
+ Say, img indexes - [0, 10, 100]
277
+ Then, caption indexes - [0, 1, 2, 3, 4, 50, 51, 52, 53, 54, 100, 501, 502, 503, 504]
278
+
279
+ :param: img_idxs: image (and label) indexes
280
+
281
+ :return: caption indexes
282
+ """
283
+ caption_idxs = []
284
+ for idx in img_idxs:
285
+ for i in range(num): # each image has 5 captions
286
+ caption_idxs.append(idx * num + i)
287
+ return caption_idxs
288
+
289
+ def split_data(images, captions, labels, captions_aug, images_aug, args):
290
+ """
291
+ Split dataset to get training, query and db subsets
292
+
293
+ :param: images: image embeddings array
294
+ :param: captions: caption embeddings array
295
+ :param: labels: labels array
296
+ :param: captions_aug: augmented caption embeddings
297
+ :param: images_aug: augmented image embeddings
298
+
299
+ :return: tuples of (images, captions, labels), each element is array
300
+ """
301
+ idx_tr, idx_q, idx_db = get_split_idxs(len(images), args)
302
+ idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)
303
+
304
+ train = images[idx_tr], captions[idx_tr_cap], labels[idx_tr], (idx_tr, idx_tr_cap), captions_aug[idx_tr_cap], \
305
+ images_aug[idx_tr]
306
+ query = images[idx_q], captions[idx_q_cap], labels[idx_q], (idx_q, idx_q_cap), captions_aug[idx_q_cap], \
307
+ images_aug[idx_q]
308
+ db = images[idx_db], captions[idx_db_cap], labels[idx_db], (idx_db, idx_db_cap), captions_aug[idx_db_cap], \
309
+ images_aug[idx_db]
310
+
311
+ return train, query, db
312
+
313
+ def select_idxs(seq_length, n_to_select, n_from_select, seed=42):
314
+ """
315
+ Select n_to_select indexes from each consequent n_from_select indexes from range with length seq_length, split
316
+ selected indexes to separate arrays
317
+
318
+ Example:
319
+
320
+ seq_length = 20
321
+ n_from_select = 5
322
+ n_to_select = 2
323
+
324
+ input, range of length seq_length:
325
+ range = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
326
+
327
+ sequences of length n_from_select:
328
+ sequences = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]
329
+
330
+ selected n_to_select elements from each sequence
331
+ selected = [[0, 4], [7, 9], [13, 14], [16, 18]]
332
+
333
+ output, n_to_select lists of length seq_length / n_from_select:
334
+ output = [[0, 7, 13, 16], [4, 9, 14, 18]]
335
+
336
+ :param seq_length: length of sequence, say 10
337
+ :param n_to_select: number of elements to select
338
+ :param n_from_select: number of consequent elements
339
+ :return:
340
+ """
341
+ random.seed(seed)
342
+ idxs = [[] for _ in range(n_to_select)]
343
+ for i in range(seq_length // n_from_select):
344
+ ints = random.sample(range(n_from_select), n_to_select)
345
+ for j in range(n_to_select):
346
+ idxs[j].append(i * n_from_select + ints[j])
347
+ return idxs
348
+
349
+ class AbstractDataset(torch.utils.data.Dataset):
350
+
351
+ def __init__(self, images, captions, labels, targets, idxs):
352
+
353
+ self.image_replication_factor = 1 # default value, how many times we need to replicate image
354
+
355
+ self.images = images
356
+ self.captions = captions
357
+ self.labels = labels
358
+ self.targets = targets
359
+
360
+ self.idxs = np.array(idxs[0])
361
+
362
+
363
+ def __getitem__(self, index):
364
+ return
365
+
366
+ def __len__(self):
367
+ return
368
+
369
+ class CISENDataset(torch.utils.data.Dataset):
370
+ """
371
+ Class for dataset representation.
372
+ Each image has 5 corresponding captions
373
+ Duplet dataset sample - img-txt (image and corresponding caption)
374
+ """
375
+ def __init__(self, images, captions, args):
376
+ """
377
+ Initialization.
378
+ :param images: image embeddings vector
379
+ :param captions: captions embeddings vector
380
+ :param labels: labels vector
381
+ """
382
+ super().__init__()
383
+
384
+ self.images = images
385
+ self.captions = captions
386
+ # self.targets = targets
387
+ # self.labels = labels
388
+
389
+ self.word_len = args.word_len
390
+
391
+ def __getitem__(self, index):
392
+ """
393
+ Returns a tuple (img, txt, label) - image and corresponding caption
394
+ :param index: index of sample
395
+ :return: tuple (img, txt, label)
396
+ """
397
+ return (
398
+ torch.from_numpy(self.images[index].astype('float32')),
399
+ torch.from_numpy(np.array(tokenize(self.captions[index], self.word_len).squeeze(0)).astype('int64'))
400
+ # ,torch.from_numpy(self.targets[index])
401
+ )
402
+
403
+ def __len__(self):
404
+ return len(self.images)
405
+
406
+
407
+ class DatasetDuplet(AbstractDataset):
408
+ """
409
+ Class for dataset representation.
410
+ Each image has 5 corresponding captions
411
+ Duplet dataset sample - img-txt (image and corresponding caption)
412
+ """
413
+ def __init__(self, images, captions, labels, targets, idxs, args):
414
+ """
415
+ Initialization.
416
+ :param images: image embeddings vector
417
+ :param captions: captions embeddings vector
418
+ :param labels: labels vector
419
+ """
420
+ super().__init__(images, captions, labels, targets, idxs)
421
+
422
+ self.word_len = args.word_len
423
+
424
+ def __getitem__(self, index):
425
+ """
426
+ Returns a tuple (img, txt, label) - image and corresponding caption
427
+ :param index: index of sample
428
+ :return: tuple (img, txt, label)
429
+ """
430
+ return (
431
+ index,
432
+ torch.from_numpy(self.images[index].astype('float32')),
433
+ torch.from_numpy(np.array(tokenize(self.captions[index] + self.captions[index], self.word_len).squeeze(0)).astype('int64')),
434
+ self.labels[index],
435
+ self.targets[index]
436
+ )
437
+
438
+ def __len__(self):
439
+ return len(self.images)
440
+
441
+ class ModifiedDatasetDuplet(AbstractDataset):
442
+ """
443
+ Class for dataset representation.
444
+
445
+ Each image has 5 corresponding captions
446
+
447
+ Duplet dataset sample - img-txt (image and corresponding caption)
448
+ """
449
+
450
+ def __init__(self, images, captions, labels, targets, idxs, args):
451
+ """
452
+ Initialization.
453
+
454
+ :param images: image embeddings vector
455
+ :param captions: captions embeddings vector
456
+ :param labels: labels vector
457
+ """
458
+ super().__init__(images, captions, labels, targets, idxs)
459
+
460
+
461
+ def __getitem__(self, index):
462
+ """
463
+ Returns a tuple (img, txt, label) - image and corresponding caption
464
+
465
+ :param index: index of sample
466
+ :return: tuple (img, txt, label)
467
+ """
468
+ text = text_tokenize(self.captions[index], return_tensors='pt', padding='max_length', truncation='longest_first')['input_ids']
469
+ return (
470
+ index,
471
+ torch.from_numpy(self.images[index].astype('float32')),
472
+ torch.from_numpy(np.array(text_tokenize(self.captions[index], return_tensors='pt', padding='max_length', truncation='longest_first')['input_ids']).astype('int64')),
473
+ self.labels[index],
474
+ self.targets[index]
475
+ )
476
+
477
+ def __len__(self):
478
+ return len(self.images)
cisen/utils/hash.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ import math
7
+
8
+ def init_hash(dataloader, args):
9
+ dataset_size = len(dataloader.dataset)
10
+ B = torch.randn(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
11
+ H = torch.zeros(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
12
+ Hi = torch.zeros(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
13
+ Ht = torch.zeros(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
14
+
15
+ return B, H, Hi, Ht
16
+
17
+ def GenerateCode(model, data_loader, args):
18
+
19
+ num_data = len(data_loader.dataset)
20
+ B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
21
+ Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
22
+ Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
23
+ for i, (idx, image, text, label, target) in enumerate(data_loader, 0):
24
+ image = image.cuda(non_blocking = True)
25
+ text = text.cuda(non_blocking = True)
26
+
27
+ img_hash, txt_hash, output, output_s = model(image, text)
28
+
29
+ B[idx, :] = torch.sign(output.detach().cpu()).numpy()
30
+ Bi[idx, :] = torch.sign(img_hash.detach().cpu()).numpy()
31
+ Bt[idx, :] = torch.sign(txt_hash.detach().cpu()).numpy()
32
+
33
+ return B, Bi, Bt
34
+
35
+
36
+ def CalcSim(batch_label, train_label):
37
+ S = (batch_label.mm(train_label.t()) > 0)
38
+ return S
39
+
40
+ # loss
41
+ def Logtrick(x):
42
+
43
+ lt = torch.log(1+torch.exp(-torch.abs(x))).cuda() + torch.max(x, Variable(torch.FloatTensor([0.]).cuda()))
44
+
45
+ return lt
46
+
47
+ class NTXentLoss(nn.Module):
48
+
49
+ """
50
+ Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).
51
+
52
+ Contains single-modal and cross-modal implementations.
53
+
54
+ """
55
+
56
+ def __init__(self, temperature=1, eps=1e-6):
57
+ super(NTXentLoss, self).__init__()
58
+ self.temperature = temperature
59
+ self.eps = eps
60
+
61
+ def forward(self, *args, type='orig'):
62
+ if type == 'cross':
63
+ return self.forward_cross_modal(*args)
64
+ if type == 'orig':
65
+ return self.forward_orig(*args)
66
+ if type == 'both':
67
+ return self.forward_orig(*args), self.forward_cross_modal(*args)
68
+ else:
69
+ raise Exception("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")
70
+
71
+ def forward_cross_modal(self, mod1, mod2):
72
+ """
73
+ Cross-modal case:
74
+
75
+ p - positive pair
76
+ n - negative pair
77
+ sim - cosine similarity
78
+
79
+ ix - image modality feature number x
80
+ tx - text modality feature number x
81
+
82
+ Cross-modal case of NTXent doesn't consider similarities inside of the same modality
83
+
84
+ Similarities matrix: exp(sim(i, y))
85
+ +--+--+--+--+--+--+--+
86
+ | |i1|i2|i3|t1|t2|t3|
87
+ Modality +--+--+--+--+--+--+--+
88
+ Features |i1|0 |0 |0 |p |n |n |
89
+ +--+ +--+ +--+--+--+--+--+--+--+
90
+ |i1| |t1| |i2|0 |0 |0 |n |p |n |
91
+ +--+ +--+ +--+--+--+--+--+--+--+
92
+ |i2| |t2| ------> |i3|0 |0 |0 |n |n |p |
93
+ +--+ +--+ +--+--+--+--+--+--+--+
94
+ |i3| |t3| |t1|p |n |n |0 |0 |0 |
95
+ +--+ +--+ +--+--+--+--+--+--+--+
96
+ |t2|n |p |n |0 |0 |0 |
97
+ +--+--+--+--+--+--+--+
98
+ |t3|n |n |p |0 |0 |0 |
99
+ +--+--+--+--+--+--+--+
100
+
101
+ :param: mod1: features of the 1st modality
102
+ :param: mod1: features of the 2nd modality
103
+ :return: NTXent loss
104
+
105
+ """
106
+ # normalize for numerical stability
107
+ mod1 = F.normalize(mod1)
108
+ mod2 = F.normalize(mod2)
109
+
110
+ out = torch.cat([mod1, mod2], dim=0)
111
+
112
+ # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
113
+
114
+ cov = torch.mm(out, out.t().contiguous()) # cosine similarities matrix
115
+ sim = torch.exp(cov / self.temperature)
116
+
117
+ # mask for cross-modal case, nullifies certain regions (see docstring)
118
+ zeros = torch.zeros(mod1.shape[0], mod1.shape[0]).to(sim.device)
119
+ ones = torch.ones(mod1.shape[0], mod1.shape[0]).to(sim.device)
120
+ mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])]).to(sim.device)
121
+
122
+ sim = sim * mask
123
+
124
+ # neg: [2 * batch_size]
125
+ # negative pairs sum
126
+ neg = sim.sum(dim=1)
127
+
128
+ # Positive similarity, pos becomes [2 * batch_size]
129
+ pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
130
+ pos = torch.cat([pos, pos], dim=0)
131
+
132
+ loss = -torch.log(pos / (neg + self.eps)).sum()
133
+ return loss
134
+
135
+ def forward_orig(self, out_1, out_2):
136
+ """
137
+ Implementation taken from:
138
+ https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py
139
+
140
+ p - positive pair
141
+ n - negative pair
142
+ sim - cosine similarity
143
+ e - Euler's number
144
+
145
+ ix - value x of input feature vector i
146
+ tx - value x of input feature vector t
147
+
148
+ Similarities matrix: exp(sim(i, y))
149
+ +--+--+--+--+--+--+--+
150
+ | |i1|i2|i3|t1|t2|t3|
151
+ Modality +--+--+--+--+--+--+--+
152
+ Features |i1|e |n |n |p |n |n |
153
+ +--+ +--+ +--+--+--+--+--+--+--+
154
+ |i1| |t1| |i2|n |e |n |n |p |n |
155
+ +--+ +--+ +--+--+--+--+--+--+--+
156
+ |i2| |t2| ------> |i3|n |n |e |n |n |p |
157
+ +--+ +--+ +--+--+--+--+--+--+--+
158
+ |i3| |t3| |t1|p |n |n |e |n |n |
159
+ +--+ +--+ +--+--+--+--+--+--+--+
160
+ |t2|n |p |n |n |e |n |
161
+ +--+--+--+--+--+--+--+
162
+ |t3|n |n |p |n |n |e |
163
+ +--+--+--+--+--+--+--+
164
+
165
+ :param out_1: input feature vector i
166
+ :param out_2: input feature vector t
167
+ :return: NTXent loss
168
+ """
169
+ out_1 = F.normalize(out_1)
170
+ out_2 = F.normalize(out_2)
171
+
172
+ out = torch.cat([out_1, out_2], dim=0)
173
+
174
+ # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
175
+ # neg: [2 * batch_size]
176
+ cov = torch.mm(out, out.t().contiguous())
177
+ sim = torch.exp(cov / self.temperature)
178
+ neg = sim.sum(dim=-1)
179
+
180
+ # from each row, subtract e^1 to remove similarity measure for x1.x1
181
+ row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device)
182
+ neg = torch.clamp(neg - row_sub, min=self.eps) # clamp for numerical stability
183
+
184
+ # Positive similarity, pos becomes [2 * batch_size]
185
+ o = out_1 * out_2
186
+ pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
187
+ pos = torch.cat([pos, pos], dim=0)
188
+
189
+ loss = -torch.log(pos / (neg + self.eps)).mean()
190
+ return loss
191
+
192
+
193
+
194
+ """
195
+
196
+ out_hash: real-value code
197
+
198
+ H: total real-value code
199
+
200
+ Bbatch: batch hash code
201
+
202
+ S: similarity
203
+
204
+ num_train: number of train
205
+
206
+ num_batch: batchsize
207
+
208
+ """
209
+
210
+ def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args):
211
+ theta_x = out_hash.float().mm(Variable(H.cuda()).t()) / 2
212
+
213
+ logloss = (Variable(S.cuda()) * theta_x - Logtrick(theta_x)).sum() \
214
+ / (num_train * num_batch)
215
+
216
+ regterm = (Bbatch - out_hash).pow(2).sum() / (num_train * num_batch)
217
+
218
+
219
+ loss_p = - logloss + args.lamda * regterm
220
+ return logloss, regterm, loss_p
221
+
222
+ def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
223
+ """
224
+ Calculate NTXent Loss
225
+
226
+ :param: h_img1: batch of image hashes #1 (original)
227
+ :param: h_img2: batch of image hashes #2 (augmented)
228
+ :param: h_txt1: batch of text hashes #1 (original)
229
+ :param: h_txt2: batch of text hashes #2 (augmented)
230
+
231
+ :returns: NTXent Loss
232
+ """
233
+ loss_ntxent_inter1 = Criterion(img_hash, txt_hash, type='cross')
234
+ loss_ntxent_inter2 = Criterion(img_hash, out_hash, type='orig')
235
+ loss_ntxent_inter3 = Criterion(out_hash, txt_hash, type='orig')
236
+ # loss_ntxent_intra = Criterion(out_hash, out_hash, type='orig') * args.contrastive_weights[1]
237
+
238
+ loss_ntxent = loss_ntxent_inter1 * args.contrastive[0] + loss_ntxent_inter2 * args.contrastive[1] + loss_ntxent_inter3 * args.contrastive[2]
239
+ return loss_ntxent
240
+
241
+ def Calc_total_loss(H, B, S, num_train, args):
242
+ theta = H.mm(H.t()) / 2
243
+ t1 = (theta*theta).sum() / (num_train * num_train)
244
+ logloss = (- theta * S + Logtrick(Variable(theta)).data).sum()
245
+ regterm = (H - B).pow(2).sum()
246
+ loss_p = logloss + args.lamda * regterm
247
+
248
+ return logloss, regterm, loss_p
249
+
250
+ def CalcHammingDist(B1, B2):
251
+ q = B2.shape[1]
252
+ distH = 0.5 * (q - np.dot(B1, B2.transpose()))
253
+ return distH
254
+
255
+ def CalcMap(qB, rB, queryL, retrievalL):
256
+ # qB: m, q
257
+ # rB: n, q
258
+ # queryL: {0,1}^{mxl}
259
+ # retrievalL: {0,1}^{nxl}
260
+ num_query = queryL.shape[0]
261
+ map = 0
262
+ # print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
263
+
264
+ for iter in range(num_query):
265
+ # 标签匹配
266
+ gnd = (np.dot(queryL[iter, :], retrievalL.transpose()) > 0).astype(np.float32)
267
+ tsum = np.sum(gnd)
268
+ if tsum == 0:
269
+ continue
270
+ # 计算query 与 database之间的汉明距离
271
+ hamm = CalcHammingDist(qB[iter, :], rB)
272
+ # 排序
273
+ ind = np.argsort(hamm)
274
+ # 汉明距离与标签对应
275
+ gnd = gnd[ind]
276
+ count = np.linspace(1, int(tsum), int(tsum))
277
+ # 按照结果排序比对是否标签一致,并返回一致的坐标
278
+ tindex = np.asarray(np.where(gnd == 1)) + 1.0
279
+ map_ = np.mean(count / (tindex))
280
+ # print(map_)
281
+ map = map + map_
282
+ map = map / num_query
283
+ # print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
284
+
285
+ return map
286
+
287
+
288
+ def CalcTopMap(qB, rB, queryL, retrievalL, topk = 20):
289
+ # qB: {-1,+1}^{mxq}
290
+ # rB: {-1,+1}^{nxq}
291
+ # queryL: {0,1}^{mxl}
292
+ # retrievalL: {0,1}^{nxl}
293
+ num_query = queryL.shape[0]
294
+ topkmap = 0
295
+ # print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
296
+ for iter in range(num_query):
297
+ gnd = (np.dot(queryL[iter, :], retrievalL.transpose()) > 0).astype(np.float32)
298
+ hamm = CalcHammingDist(qB[iter, :], rB)
299
+ ind = np.argsort(hamm)
300
+ gnd = gnd[ind]
301
+
302
+ tgnd = gnd[0:topk]
303
+ tsum = np.sum(tgnd)
304
+ if tsum == 0:
305
+ continue
306
+ count = np.linspace(1, int(tsum), int(tsum))
307
+
308
+ tindex = np.asarray(np.where(tgnd == 1)) + 1.0
309
+ topkmap_ = np.mean(count / (tindex))
310
+ # print(topkmap_)
311
+ topkmap = topkmap + topkmap_
312
+ topkmap = topkmap / num_query
313
+ # print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
314
+ return topkmap
cisen/utils/misc.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from loguru import logger
6
+ import sys
7
+ import inspect
8
+ import math
9
+ import torch
10
+ import torch.distributed as dist
11
+ from collections import OrderedDict
12
+ from torch import nn
13
+
14
+ def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
15
+ """Initialize random seed."""
16
+ if seed is not None:
17
+ return seed
18
+
19
+ # Make sure all ranks share the same random seed to prevent
20
+ # some potential bugs. Please refer to
21
+ # https://github.com/open-mmlab/mmdetection/issues/6339
22
+ seed = np.random.randint(2**31)
23
+ if world_size == 1:
24
+ return seed
25
+
26
+ if rank == 0:
27
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
28
+ else:
29
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
30
+ dist.broadcast(random_num, src=0)
31
+ return random_num.item()
32
+
33
+ def set_random_seed(seed, deterministic=False):
34
+ """Set random seed."""
35
+ random.seed(seed)
36
+ np.random.seed(seed)
37
+ torch.manual_seed(seed)
38
+ torch.cuda.manual_seed_all(seed)
39
+ if deterministic:
40
+ torch.backends.cudnn.deterministic = True
41
+ torch.backends.cudnn.benchmark = False
42
+
43
+ def worker_init_fn(worker_id, num_workers, rank, seed):
44
+ # The seed of each worker equals to
45
+ # num_worker * rank + worker_id + user_seed
46
+ worker_seed = num_workers * rank + worker_id + seed
47
+ np.random.seed(worker_seed)
48
+ random.seed(worker_seed)
49
+
50
+ class AverageMeter(object):
51
+ """Computes and stores the average and current value"""
52
+
53
+ def __init__(self, name, fmt=":f"):
54
+ self.name = name
55
+ self.fmt = fmt
56
+ self.reset()
57
+
58
+ def reset(self):
59
+ self.val = 0
60
+ self.avg = 0
61
+ self.sum = 0
62
+ self.count = 0
63
+
64
+ def update(self, val, n=1):
65
+ self.val = val
66
+ self.sum += val * n
67
+ self.count += n
68
+ self.avg = self.sum / self.count
69
+
70
+ def __str__(self):
71
+ if self.name == "Lr":
72
+ fmtstr = "{name}={val" + self.fmt + "}"
73
+ else:
74
+ fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
75
+ return fmtstr.format(**self.__dict__)
76
+
77
+ class ProgressMeter(object):
78
+ def __init__(self, num_batches, meters, prefix=""):
79
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
80
+ self.meters = meters
81
+ self.prefix = prefix
82
+
83
+ def display(self, batch):
84
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
85
+ entries += [str(meter) for meter in self.meters]
86
+ logger.info(" ".join(entries))
87
+
88
+ def _get_batch_fmtstr(self, num_batches):
89
+ num_digits = len(str(num_batches // 1))
90
+ fmt = "{:" + str(num_digits) + "d}"
91
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
92
+
93
+ def get_caller_name(depth=0):
94
+ """
95
+ Args:
96
+ depth (int): Depth of caller conext, use 0 for caller depth.
97
+ Default value: 0.
98
+
99
+ Returns:
100
+ str: module name of the caller
101
+ """
102
+ # the following logic is a little bit faster than inspect.stack() logic
103
+ frame = inspect.currentframe().f_back
104
+ for _ in range(depth):
105
+ frame = frame.f_back
106
+
107
+ return frame.f_globals["__name__"]
108
+
109
+ class StreamToLoguru:
110
+ """
111
+ stream object that redirects writes to a logger instance.
112
+ """
113
+ def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
114
+ """
115
+ Args:
116
+ level(str): log level string of loguru. Default value: "INFO".
117
+ caller_names(tuple): caller names of redirected module.
118
+ Default value: (apex, pycocotools).
119
+ """
120
+ self.level = level
121
+ self.linebuf = ""
122
+ self.caller_names = caller_names
123
+
124
+ def write(self, buf):
125
+ full_name = get_caller_name(depth=1)
126
+ module_name = full_name.rsplit(".", maxsplit=-1)[0]
127
+ if module_name in self.caller_names:
128
+ for line in buf.rstrip().splitlines():
129
+ # use caller level log
130
+ logger.opt(depth=2).log(self.level, line.rstrip())
131
+ else:
132
+ sys.__stdout__.write(buf)
133
+
134
+ def flush(self):
135
+ pass
136
+
137
+ def redirect_sys_output(log_level="INFO"):
138
+ redirect_logger = StreamToLoguru(log_level)
139
+ sys.stderr = redirect_logger
140
+ sys.stdout = redirect_logger
141
+
142
+ def setup_logger(save_dir, filename="log.txt", mode="a"):
143
+ """setup logger for training and testing.
144
+ Args:
145
+ save_dir(str): location to save log file
146
+ distributed_rank(int): device rank when multi-gpu environment
147
+ filename (string): log save name.
148
+ mode(str): log file write mode, `append` or `override`. default is `a`.
149
+
150
+ Return:
151
+ logger instance.
152
+ """
153
+ loguru_format = (
154
+ "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
155
+ "<level>{level: <8}</level> | "
156
+ "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
157
+
158
+ logger.remove()
159
+ save_file = os.path.join(save_dir, filename)
160
+ if mode == "o" and os.path.exists(save_file):
161
+ os.remove(save_file)
162
+ # only keep logger in rank0 process
163
+
164
+ logger.add(
165
+ sys.stderr,
166
+ format=loguru_format,
167
+ level="INFO",
168
+ enqueue=True,
169
+ )
170
+ logger.add(save_file)
171
+
172
+ # redirect stdout/stderr to loguru
173
+ redirect_sys_output("INFO")
174
+
175
+ def trainMetric(pred, label):
176
+ pred = torch.argmax(pred,dim = 1)
177
+ prec = torch.sum(pred == label)
178
+
179
+ return prec
180
+
181
+ # def compute_AP(predicted_probs, true_labels):
182
+ # num_samples, num_classes = true_labels.shape
183
+ #
184
+ # # 初始化用于存储每个类别的 AP 的列表
185
+ # aps = []
186
+ #
187
+ # for class_idx in range(num_classes):
188
+ # class_true_labels = true_labels[:, class_idx]
189
+ # class_similarity_scores = predicted_probs[:, class_idx]
190
+ #
191
+ # # 获取按相似性分数排序后的样本索引
192
+ # sorted_indices = torch.argsort(class_similarity_scores, descending=True)
193
+ #
194
+ # # 计算累积精度和召回率
195
+ # tp = 0
196
+ # fp = 0
197
+ # precision_at_rank = []
198
+ # recall_at_rank = []
199
+ #
200
+ # for rank, idx in enumerate(sorted_indices):
201
+ # if class_true_labels[idx] == 1:
202
+ # tp += 1
203
+ # else:
204
+ # fp += 1
205
+ # precision = tp / (tp + fp)
206
+ # recall = tp / torch.sum(class_true_labels)
207
+ # precision_at_rank.append(precision)
208
+ # recall_at_rank.append(recall)
209
+ #
210
+ # # 计算平均精度(AP)通过计算曲线下的面积
211
+ # precision_at_rank = torch.tensor(precision_at_rank)
212
+ # recall_at_rank = torch.tensor(recall_at_rank)
213
+ # ap = torch.trapz(precision_at_rank, recall_at_rank)
214
+ #
215
+ # aps.append(ap)
216
+ #
217
+ #
218
+ # return aps
219
+ def token_wise_similarity(rep1, rep2, mask=None, chunk_size=1024):
220
+ batch_size1, n_token1, feat_dim = rep1.shape
221
+ batch_size2, n_token2, _ = rep2.shape
222
+ num_folds = math.ceil(batch_size2 / chunk_size)
223
+ output = []
224
+ for i in range(num_folds):
225
+ rep2_seg = rep2[i * chunk_size:(i + 1) * chunk_size]
226
+ out_i = rep1.reshape(-1, feat_dim) @ rep2_seg.reshape(-1, feat_dim).T
227
+ out_i = out_i.reshape(batch_size1, n_token1, -1, n_token2).max(3)[0]
228
+ if mask is None:
229
+ out_i = out_i.mean(1)
230
+ else:
231
+ out_i = out_i.sum(1)
232
+ output.append(out_i)
233
+ output = torch.cat(output, dim=1)
234
+ if mask is not None:
235
+ output = output / mask.sum(1, keepdim=True).clamp_(min=1)
236
+ return output
237
+
238
+ def compute_acc(logits, targets, topk=5):
239
+ targets = targets.squeeze(1)
240
+ p = logits.topk(topk, 1, True, True)[1]
241
+ pred = logits.topk(topk, 1, True, True)[1]
242
+ gt = targets[pred,:]
243
+
244
+ a = gt.view(1, -1)
245
+
246
+ # b = a.expand_as(pred)
247
+ c = gt.eq(targets)
248
+ correct = pred.eq(targets.view(1, -1).expand_as(pred)).contiguous()
249
+ acc_1 = correct[:1].sum(0)
250
+ acc_k = correct[:topk].sum(0)
251
+ return acc_1, acc_k
252
+
253
+ def compute_mAP(predicted_probs, true_labels):
254
+ aps = compute_AP(predicted_probs, true_labels)
255
+ aps = [ap for ap in aps if not torch.isnan(ap)]
256
+ mAP = torch.mean(torch.tensor(aps))
257
+ return mAP
258
+
259
+ def compute_F1(predictions, labels, k_val=5):
260
+ labels = labels.squeeze(1)
261
+ idx = predictions.topk(dim=1, k=k_val)[1]
262
+ predictions.fill_(0)
263
+ predictions.scatter_(dim=1, index=idx, src=torch.ones(predictions.size(0), k_val).to(predictions.device))
264
+ mask = predictions == 1
265
+ TP = (labels[mask] == 1).sum().float()
266
+ tpfp = mask.sum().float()
267
+ tpfn = (labels == 1).sum().float()
268
+ p = TP / tpfp
269
+ r = TP/tpfn
270
+ f1 = 2*p*r/(p+r)
271
+
272
+ return f1, p, r
273
+
274
+ def compute_AP(predictions, labels):
275
+ num_class = predictions.size(1)
276
+ ap = torch.zeros(num_class).to(predictions.device)
277
+ empty_class = 0
278
+ for idx_cls in range(num_class):
279
+ prediction = predictions[:, idx_cls]
280
+ label = labels[:, idx_cls]
281
+ mask = label.abs() == 1
282
+ if (label > 0).sum() == 0:
283
+ empty_class += 1
284
+ continue
285
+ binary_label = torch.clamp(label[mask], min=0, max=1)
286
+ sorted_pred, sort_idx = prediction[mask].sort(descending=True)
287
+ sorted_label = binary_label[sort_idx]
288
+ tmp = (sorted_label == 1).float()
289
+ tp = tmp.cumsum(0)
290
+ fp = (sorted_label != 1).float().cumsum(0)
291
+ num_pos = binary_label.sum()
292
+ rec = tp/num_pos
293
+ prec = tp/(tp+fp)
294
+ ap_cls = (tmp*prec).sum()/num_pos
295
+ ap[idx_cls].copy_(ap_cls)
296
+ return ap, empty_class
297
+
298
+ def compute_ACG(predictions, labels, k_val=5):
299
+ gt = labels.squeeze(1)
300
+ idx = predictions.topk(dim=1, k=k_val)[1]
301
+ pred = gt[idx, :]
302
+ pred[pred == -1] = 0
303
+ c = labels.eq(pred) # common label
304
+ r = c.sum(-1) # similarity level
305
+ # acg
306
+ acg = c.sum(-1).sum(-1) / k_val
307
+ lg = torch.log1p(torch.arange(1, k_val+1, 1) ).to(r.device)
308
+ # dcg
309
+ dcg = (torch.pow(2, r) - 1) / lg
310
+ ir, _ = r.sort(-1, descending=True)
311
+ idcg = (torch.pow(2, ir) - 1) / lg
312
+ idcg[idcg == 0] = 1e-6
313
+ ndcg = dcg.sum(-1) / idcg.sum(-1)
314
+ # map
315
+ pos = r.clone()
316
+ pos[pos != 0] = 1
317
+ j = torch.arange(1, k_val + 1, 1).to(pos.device)
318
+ P = torch.cumsum(pos, 1) / j
319
+ Npos = torch.sum(pos, 1)
320
+ Npos[Npos == 0] = 1
321
+ AP = torch.sum(P * pos, 1)
322
+ map = torch.sum(P * pos, 1) / Npos
323
+ # wmap
324
+ acgj = torch.cumsum(r, 1) / j
325
+ wmap = torch.sum(acgj * pos, 1) / Npos
326
+
327
+
328
+
329
+ return acg, ndcg, map, wmap
330
+
331
+ def compute_mAPw(predictions, labels, k_val=5):
332
+ gt = labels.squeeze(1)
333
+ idx = predictions.topk(dim=1, k=k_val)[1]
334
+ pred = gt[idx, :]
335
+ pred[pred == -1] = 0
336
+ c = labels.eq(pred)
337
+ r = c.sum(-1)
338
+ pos = r.clone()
339
+ pos[pos != 0] = 1
340
+ P = torch.cumsum(pos) / torch.arange(1, k_val+1, 1)
341
+
342
+
343
+ def adjust_learning_rate(optimizer, epoch, args):
344
+ """Decay the learning rate with half-cycle cosine after warmup"""
345
+ if epoch < args.warmup_epochs:
346
+ lr = args.base_lr * epoch / args.warmup_epochs
347
+ else:
348
+ lr = args.min_lr + (args.base_lr - args.min_lr) * 0.5 * \
349
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
350
+ for param_group in optimizer.param_groups:
351
+ if "lr_scale" in param_group:
352
+ param_group["lr"] = lr * param_group["lr_scale"]
353
+ else:
354
+ param_group["lr"] = lr
355
+ return lr
356
+
357
+ def load_ckpt(weight_dir, model, map_location, args):
358
+ checkpoint = torch.load(weight_dir, map_location=map_location)
359
+ if args.resume:
360
+ resume_epoch = checkpoint['epoch']
361
+ else:
362
+ resume_epoch = 0
363
+ pre_weight = checkpoint['state_dict']
364
+
365
+ new_pre_weight = OrderedDict()
366
+ # pre_weight =torch.jit.load(resume)
367
+ model_dict = model.state_dict()
368
+ new_model_dict = OrderedDict()
369
+ for k, v in pre_weight.items():
370
+ new_k = k.replace('module.', '') if 'module' in k else k
371
+ # 针对batch_size=1
372
+ # new_k = new_k.replace('1','2') if 'proj.1' in new_k else new_k
373
+ new_pre_weight[new_k] = v
374
+ # for k, v in model_dict.items():
375
+ # new_k = k.replace('module.', '') if 'module' in k else k
376
+ # new_model_dict[new_k] = v
377
+ pre_weight = new_pre_weight # ["model_state"]
378
+ # pretrained_dict = {}
379
+ # t_n = 0
380
+ # v_n = 0
381
+ # for k, v in pre_weight.items():
382
+ # t_n += 1
383
+ # if k in new_model_dict:
384
+ # k = 'module.' + k if 'module' not in k else k
385
+ # v_n += 1
386
+ # pretrained_dict[k] = v
387
+ # print(k)
388
+ # os._exit()
389
+ # print(f'{v_n}/{t_n} weights have been loaded!')
390
+ model_dict.update(pre_weight)
391
+ model.load_state_dict(model_dict, strict=False)
392
+
393
+ return model, resume_epoch
394
+
395
+ def load_ckpt_fpn(weight_dir, model, map_location):
396
+
397
+ pre_weight = torch.load(weight_dir, map_location=map_location)['state_dict']
398
+ epoch = torch.load(weight_dir, map_location=map_location)['epoch']
399
+ new_pre_weight = OrderedDict()
400
+ # pre_weight =torch.jit.load(resume)
401
+ model_dict = model.state_dict()
402
+
403
+ for k, v in pre_weight.items():
404
+ new_k = k.replace('module.', '') if 'module' in k else k
405
+ # if not (new_k.startswith('FPN') or new_k.startswith('gap')):
406
+ new_pre_weight[new_k] = v
407
+
408
+ pre_weight = new_pre_weight
409
+ # ["model_state"]
410
+ model_dict.update(pre_weight)
411
+ model.load_state_dict(model_dict, strict=True)
412
+
413
+ return model, epoch
414
+ def load_ckpt_old(weight_dir, model, map_location):
415
+
416
+ pre_weight = torch.load(weight_dir, map_location=map_location)['state_dict']
417
+ epoch = torch.load(weight_dir, map_location=map_location)['epoch']
418
+ new_pre_weight = OrderedDict()
419
+ # pre_weight =torch.jit.load(resume)
420
+ model_dict = model.state_dict()
421
+
422
+ for k, v in pre_weight.items():
423
+ new_k = k.replace('module.', '') if 'module' in k else k
424
+ if not (new_k.startswith('FPN') or new_k.startswith('gap')):
425
+ new_pre_weight[new_k] = v
426
+
427
+ pre_weight = new_pre_weight
428
+ # ["model_state"]
429
+ model_dict.update(pre_weight)
430
+ model.load_state_dict(model_dict, strict=False)
431
+
432
+ return model, epoch
433
+
434
+ def compare_ckpt(model1, model2):
435
+ V = dict()
436
+ for k, v in model1.items():
437
+ if k.startswith('projT'):
438
+ V[k] = v
439
+
440
+ for k, v in model2.items():
441
+ if k in sorted(V.keys()):
442
+ model2[k] = V[k]
443
+
444
+ return model2
cisen/utils/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text