Zhonathon committed on
Commit
aa7fb02
1 Parent(s): 21ba5c7

update all file v1

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. Model/AttDes/__init__.py +4 -0
  2. Model/AttDes/__pycache__/__init__.cpython-38.pyc +0 -0
  3. Model/AttDes/dataset/data_loader.py +170 -0
  4. Model/AttDes/models/AttDes.py +17 -0
  5. Model/AttDes/models/Chinese_tokenizer.pth +3 -0
  6. Model/AttDes/models/__init__.py +1 -0
  7. Model/AttDes/models/__pycache__/__init__.cpython-38.pyc +0 -0
  8. Model/AttDes/models/language_model/bert.py +50 -0
  9. Model/AttDes/models/prefixLM.py +108 -0
  10. Model/AttDes/models/resblock.py +353 -0
  11. Model/AttDes/models/tokenizer.py +92 -0
  12. Model/AttDes/models/transformer.py +291 -0
  13. Model/AttDes/models/visual_model/Chinese_tokenizer.pth +3 -0
  14. Model/AttDes/models/visual_model/backbone.py +121 -0
  15. Model/AttDes/models/visual_model/position_encoding.py +89 -0
  16. Model/AttDes/validate_local.py +399 -0
  17. Model/AttDes/validate_local_gennerate.py +332 -0
  18. Model/CLIP/cn_clip/__init__.py +0 -0
  19. Model/CLIP/cn_clip/__pycache__/__init__.cpython-38.pyc +0 -0
  20. Model/CLIP/cn_clip/clip/__init__.py +5 -0
  21. Model/CLIP/cn_clip/clip/__pycache__/__init__.cpython-38.pyc +0 -0
  22. Model/CLIP/cn_clip/clip/__pycache__/bert_tokenizer.cpython-38.pyc +0 -0
  23. Model/CLIP/cn_clip/clip/__pycache__/utils.cpython-38.pyc +0 -0
  24. Model/CLIP/cn_clip/clip/bert_tokenizer.py +436 -0
  25. Model/CLIP/cn_clip/clip/configuration_bert.py +84 -0
  26. Model/CLIP/cn_clip/clip/model.py +504 -0
  27. Model/CLIP/cn_clip/clip/model_configs/RBT3-chinese.json +13 -0
  28. Model/CLIP/cn_clip/clip/model_configs/RN50.json +7 -0
  29. Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json +13 -0
  30. Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-large-chinese.json +13 -0
  31. Model/CLIP/cn_clip/clip/model_configs/ViT-B-16.json +7 -0
  32. Model/CLIP/cn_clip/clip/model_configs/ViT-B-32.json +7 -0
  33. Model/CLIP/cn_clip/clip/model_configs/ViT-H-14.json +8 -0
  34. Model/CLIP/cn_clip/clip/model_configs/ViT-L-14-336.json +7 -0
  35. Model/CLIP/cn_clip/clip/model_configs/ViT-L-14.json +7 -0
  36. Model/CLIP/cn_clip/clip/model_configs/for_learn.py +16 -0
  37. Model/CLIP/cn_clip/clip/modeling_bert.py +460 -0
  38. Model/CLIP/cn_clip/clip/utils.py +196 -0
  39. Model/CLIP/cn_clip/clip/vocab.txt +0 -0
  40. Model/CLIP/cn_clip/eval/__init__.py +0 -0
  41. Model/CLIP/cn_clip/eval/data.py +167 -0
  42. Model/CLIP/cn_clip/eval/evaluation.py +157 -0
  43. Model/CLIP/cn_clip/eval/evaluation_tr.py +157 -0
  44. Model/CLIP/cn_clip/eval/extract_features.py +205 -0
  45. Model/CLIP/cn_clip/eval/imagenet_zeroshot_templates.py +194 -0
  46. Model/CLIP/cn_clip/eval/make_topk_predictions.py +88 -0
  47. Model/CLIP/cn_clip/eval/make_topk_predictions_tr.py +88 -0
  48. Model/CLIP/cn_clip/eval/transform_ir_annotation_to_tr.py +36 -0
  49. Model/CLIP/cn_clip/eval/zeroshot_evaluation.py +189 -0
  50. Model/CLIP/cn_clip/preprocess/__init__.py +0 -0
Model/AttDes/__init__.py ADDED
@@ -0,0 +1,4 @@
+ import Model.AttDes.models
+ import Model.AttDes.dataset
+
+
Model/AttDes/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (208 Bytes).
 
Model/AttDes/dataset/data_loader.py ADDED
@@ -0,0 +1,170 @@
+ import os
+ import re
+ import time
+
+ import cv2
+ import sys
+ import json
+
+ import matplotlib.pyplot as plt
+ import torch
+ from torch import nn
+ import numpy as np
+ import pandas as pd
+ import os.path as osp
+ import scipy.io as sio
+ import torch.utils.data as data
+ from PIL import Image
+ import matplotlib.image as mping
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+ from pytorch_pretrained_bert.tokenization import BertTokenizer
+
+
+ def get_data_from_csv(path):
+     data_csv = pd.read_csv(path, encoding='utf-8')
+     # print(data_csv)
+     pic_id_list = data_csv['pic_id'].values
+     seg_id_list = data_csv['seg_id'].values
+     object_list = data_csv['object'].values
+     segment_list = data_csv['segment'].values
+     adj_list = data_csv['adj'].values
+     des_list = data_csv['des'].values
+
+     return pic_id_list, seg_id_list, object_list, segment_list, adj_list, des_list
+
+
+ class AttDesDataset(data.Dataset):
+
+     def __init__(self, data_root, dataset_name, img_root, dataset_split='train', transform=None,
+                  bert_model='bert-base-chinese',
+                  des_len=256, obj_len=8, tgt_len=32
+                  ):
+         self.images = []
+         self.data_root = data_root
+         self.dataset_name = dataset_name
+         self.transform = transform
+         self.img_root = img_root
+         self.tokenizer = BertTokenizer.from_pretrained(bert_model)
+         self.des_len = des_len
+         self.obj_len = obj_len
+         self.tgt_len = tgt_len
+         assert self.transform is not None
+         self.pic_id_list, self.seg_id_list, self.object_list, self.segment_list, self.adj_list, self.des_list = \
+             get_data_from_csv(self.data_root)
+         self.data_csv = pd.read_csv(data_root, encoding='utf-8')
+
+     def get_data_from_csv_by_id(self, id, dict=None):
+         pic_id_list = self.data_csv['pic_id'].values
+         # {id: des_}
+         des_list = self.data_csv['des'].values
+         start_time = time.time()
+         for i in range(len(pic_id_list)):
+             if str(pic_id_list[i]) == str(id):
+                 # print("find: str(pic_id_list[i]) == str(id)", time.time() - start_time)
+                 return des_list[i]
+         return ""
+
+     def get_img_from_id(self, img_id):
+         img_filename = self.img_root
+         img_filename = img_filename + '/' + str(img_id) + '.jpg'
+         img = Image.open(img_filename)
+         if self.transform:
+             img = self.transform(img)
+         return img
+
+     def encode_text_bert(self, text):
+         tokens = []
+         tokens.append("[CLS]")
+         token_obj = self.tokenizer.tokenize(text)
+         for token in token_obj:
+             tokens.append(token)
+         tokens.append("[SEP]")
+         tokens = self.tokenizer.convert_tokens_to_ids(tokens)
+         return tokens
+
+     def get_all_from_id(self, img_id, obj_given):
+         img_id = str(img_id)
+         if img_id[0] == '#':
+             des = ""
+         else:
+             des = self.get_data_from_csv_by_id(img_id)
+         img = self.get_img_from_id(img_id)
+         des = self.encode_text_bert(des)
+         obj_given = self.encode_text_bert(obj_given)
+         while(len(des) < self.des_len):
+             des.append(100)
+         while(len(obj_given) < self.obj_len):
+             obj_given.append(0)
+         assert len(des) == self.des_len
+         return img, torch.from_numpy(np.array(des)), torch.from_numpy(np.array(obj_given))
+
+     def __getitem__(self, idx):
+         img_id = self.pic_id_list[idx]
+         img = self.get_img_from_id(img_id)
+
+         # des = self.des_list[idx].split('[,,;]')
+         des = re.split(',|;', str(self.des_list[idx]))
+         masked_des = ""  # chinese
+         for i in range(len(des)):
+             if i != int(self.seg_id_list[idx]):
+                 masked_des = masked_des + des[i] + ' '
+
+         obj = self.object_list[idx]  # chinese
+         segment = self.segment_list[idx]  # chinese
+         masked_des = self.encode_text_bert(masked_des)
+         obj = self.encode_text_bert(obj)
+         segment = self.encode_text_bert(segment)
+         while(len(masked_des) < self.des_len):
+             masked_des.append(100)
+         while(len(obj) < self.obj_len):
+             obj.append(0)
+         while(len(segment) < self.tgt_len):
+             segment.append(0)
+
+         assert len(masked_des) == self.des_len
+         assert len(obj) == self.obj_len
+         assert len(segment) == self.tgt_len
+         return img, np.array(masked_des), np.array(obj), np.array(segment), img_id
+
+     def __len__(self):
+         return len(self.pic_id_list)
+
+
+ if __name__ == '__main__':
+     data_root = r'E:\data\Download\fur\dataset\data_for_test1.csv'
+     split_root = ''
+     dataset_name = 'Furniture'
+     #
+     # get_data_from_csv(data_root)
+     # img_id = 550709
+     # img = get_img_from_id(img_id)
+     # plt.imshow(img)
+     # plt.show()
+     normalize = transforms.Normalize(mean=[0, 0, 0],
+                                      std=[1, 1, 1])
+     dataset = AttDesDataset(data_root, dataset_name, transform=transforms.Compose([
+         transforms.Resize((448, 448)),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         normalize,
+     ]))
+     img, masked_des, obj, segment = dataset.__getitem__(100)
+
+     img_show = np.zeros((len(img[0]), len(img[0][0]), 3))
+     img_show[:, :, 0] = img[0]
+     img_show[:, :, 1] = img[1]
+     img_show[:, :, 2] = img[2]
+     plt.imshow(img_show)
+     plt.show()
+     print(masked_des, len(masked_des))
+     print(obj, len(obj))
+     print(segment, len(segment))
+     print(dataset.__len__())
+
+     # sentence_for_test = "原木地板的厚实与白色纱幔的轻飘营造朴素和浪漫的氛围,而一张编织餐椅灵动轻巧"
+     # tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
+     # tokenizer.tokenize(sentence_for_test)
+     # print(tokenizer.tokenize(sentence_for_test))
+     # print(tokenizer.convert_tokens_to_ids(sentence_for_test))
+
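Note: a minimal usage sketch for AttDesDataset (hypothetical paths; assumes a CSV with the pic_id/seg_id/object/segment/adj/des columns read by get_data_from_csv):

    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])
    # 'data.csv' and 'imgs/' are placeholder paths, not files from this commit
    dataset = AttDesDataset('data.csv', 'Furniture', img_root='imgs', transform=transform)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    # masked_des is padded to des_len with token id 100; obj and segment are padded with 0
    img, masked_des, obj, segment, img_id = dataset[0]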
Model/AttDes/models/AttDes.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from pytorch_pretrained_bert.modeling import BertModel
+
+
+ class AttDes(nn.Module):
+     def __init__(self, args):
+         super(AttDes, self).__init__()
+         hidden_dim = args.AD_hidden_dim
+
+
+
+
+
+
Model/AttDes/models/Chinese_tokenizer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2403030d0e018aedffec4a62d69b124c350a5b1ef03035395dbcb3593deca8dd
+ size 142959
Model/AttDes/models/__init__.py ADDED
@@ -0,0 +1 @@
+
Model/AttDes/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (145 Bytes).
 
Model/AttDes/models/language_model/bert.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Backbone modules.
+ """
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn.functional as F
+
+ from torch import nn
+ from typing import Dict, List
+
+ from pytorch_pretrained_bert.modeling import BertModel
+ from utils.misc import NestedTensor
+
+
+ class BERT(nn.Module):
+     def __init__(self, name: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
+         super().__init__()
+         if name == 'bert-base-uncased':
+             self.num_channels = 768
+         else:
+             self.num_channels = 1024
+         self.enc_num = enc_num
+
+         self.bert = BertModel.from_pretrained(name)
+
+         if not train_bert:
+             for parameter in self.bert.parameters():
+                 parameter.requires_grad_(False)
+
+     def forward(self, tensor_list: NestedTensor):
+
+         if self.enc_num > 0:
+             all_encoder_layers, _ = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
+             # use the output of the X-th transformer encoder layers
+             xs = all_encoder_layers[self.enc_num - 1]
+         else:
+             xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)
+
+         mask = tensor_list.mask.to(torch.bool)
+         mask = ~mask
+         out = NestedTensor(xs, mask)
+
+         return out
+
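Note: a rough sketch of how this wrapper is typically driven. NestedTensor lives in utils.misc, which is not part of this commit, so its interface (tensors = token ids, mask = attention mask) is assumed here:

    import torch
    from utils.misc import NestedTensor  # assumed helper, not included in this diff

    bert = BERT('bert-base-chinese', train_bert=False, hidden_dim=256, max_len=40, enc_num=12)
    ids = torch.randint(0, 21128, (2, 40))      # (batch, seq_len) token ids
    mask = torch.ones(2, 40, dtype=torch.long)  # 1 = real token
    out = bert(NestedTensor(ids, mask))         # out.tensors: contextual embeddings, out.mask: inverted mask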
Model/AttDes/models/prefixLM.py ADDED
@@ -0,0 +1,108 @@
+ '''
+ author: yulong-XJTU
+ '''
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import copy
+ from AttDes.models.transformer import Transformer, subsequent_mask, ModelOne, Model005, Model006
+ from axial_positional_embedding import AxialPositionalEmbedding
+ from AttDes.models.resblock import BottleneckBlock
+ from random import randint
+ from einops import rearrange
+
+
+ def clone(module, N):
+     '''copy the given module N times'''
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+ class PrefixLM(nn.Module):
+     def __init__(
+         self, des_len, obj_len, tgt_len,
+         d_model=512,
+         input_resolution=224,
+         patch_size=16,
+         num_text_tokens=10000,
+         txt_seq_len=256,
+         prefix_txt_len=25,
+         target_txt_len=52,
+         max_trunc_txt_len=15,
+         heads=8,
+         enc_depth=12,
+         dec_depth=12,
+         d_ff=1024,
+         dropout=0.,
+     ):
+         super(PrefixLM, self).__init__()
+         assert input_resolution % patch_size == 0 and max_trunc_txt_len <= prefix_txt_len and max_trunc_txt_len < txt_seq_len
+         self.ResNet = nn.Sequential(*[nn.Conv2d(in_channels=3, out_channels=64, kernel_size=patch_size, stride=patch_size, bias=True),
+                                       BottleneckBlock(in_channels=64, out_channels=256, bottleneck_channels=64),
+                                       BottleneckBlock(in_channels=256, out_channels=d_model, bottleneck_channels=128)])
+         self.des_len = des_len
+         self.obj_len = obj_len
+         self.tgt_len = tgt_len
+         self.txt_embed = nn.Embedding(num_text_tokens, d_model, padding_idx=0)
+         self.txt_pos_embed = nn.Embedding(self.des_len, d_model)
+         image_fmap_size = input_resolution // patch_size  # 448 // 16
+         self.img_tokens_len = image_fmap_size ** 2
+         # self.img_pos_embed = nn.Embedding(self.img_tokens_len, d_model)
+         self.img_pos_embed = AxialPositionalEmbedding(d_model, axial_shape=(image_fmap_size, image_fmap_size))
+         self.txt_seq_len = txt_seq_len
+         self.target_txt_len = target_txt_len
+         self.prefix_txt_len = prefix_txt_len
+
+         self.max_trunc_txt_len = max_trunc_txt_len
+         self.num_text_tokens = num_text_tokens
+         self.dim_embed = d_model
+         self.input_resolution = input_resolution
+         self.patch_size = patch_size
+         # self.temperature = nn.Parameter(torch.tensor(1.))  # not mentioned in the paper
+         self.transformer = Transformer(d_model, heads, enc_depth, dec_depth, d_ff, dropout=dropout)
+         self.ModelOne = Model005(d_model, heads, enc_depth, dec_depth, d_ff, dropout=dropout)
+         self.to_logits = nn.Sequential(
+             nn.LayerNorm(d_model),
+             nn.Linear(d_model, self.num_text_tokens)
+         )
+
+     def forward(self, img, des, obj, tgt, return_loss=False):
+         device = des.device
+         n = des.shape[0]
+         img_emed = self.ResNet(img)
+         img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
+         img_emed = img_emed + self.img_pos_embed(img_emed)
+         del img
+         # add <CLS>; if you change the tokenizer, don't forget to change the token ID.
+         # Another [SEP] token is added at the end (in tokenizer.py, please check).
+         tgt = F.pad(tgt, (1, 0), value=4)
+         labels = tgt[:, 1:]
+         tgt = tgt[:, :-1]
+         # print('des:', torch.min(des), torch.max(des))
+         des_embed = self.txt_embed(des)
+         des_embed = des_embed + self.txt_pos_embed(torch.arange(self.des_len, device=device))
+
+         obj_embed = self.txt_embed(obj)
+         obj_embed = obj_embed + self.txt_pos_embed(torch.arange(self.obj_len, device=device))
+
+         tgt_embed = self.txt_embed(tgt)
+         tgt_embed = tgt_embed + self.txt_pos_embed(torch.arange(self.tgt_len, device=device))
+         tgt_mask = subsequent_mask(self.tgt_len).to(device)
+
+         # baseline
+         # prefix = torch.cat((img_emed, des_embed, obj_embed), dim=1)
+         # tgt_mask = subsequent_mask(self.tgt_len).to(device)
+         # out = self.transformer(prefix, tgt_embed, tgt_mask=tgt_mask)
+
+         # ModelOne: self.ModelOne is the Model005 instance built in __init__
+         out = self.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
+                             tgt_embeded=tgt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
+                             tgt_mask=tgt_mask)
+
+         logits = self.to_logits(out)
+         return logits, labels
+         # if not return_loss:
+         #     return logits
+         # # temp = self.temperature.exp()
+         # logits = rearrange(logits, 'b n c -> b c n')
+         # # logits = logits * temp  # with temperature parameter
+         # loss = F.cross_entropy(logits, labels, ignore_index=0)
+         # return loss
+
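Note: a shape sketch for PrefixLM under the configuration used elsewhere in this commit (448x448 images, patch_size 16, hence 28x28 = 784 image tokens), assuming the self.ModelOne call above:

    import torch
    import torch.nn.functional as F

    model = PrefixLM(des_len=256, obj_len=8, tgt_len=35, d_model=512, input_resolution=448,
                     patch_size=16, num_text_tokens=20000, heads=4, enc_depth=8, dec_depth=8)
    img = torch.randn(2, 3, 448, 448)
    des = torch.randint(1, 20000, (2, 256))
    obj = torch.randint(1, 20000, (2, 8))
    tgt = torch.randint(1, 20000, (2, 35))
    logits, labels = model(img, des, obj, tgt)  # logits: (2, 35, 20000), labels: (2, 35)
    # cross-entropy over the shifted target, mirroring the commented-out loss above
    loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=0)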
Model/AttDes/models/resblock.py ADDED
@@ -0,0 +1,353 @@
+ # Code adapted from https://github.com/facebookresearch/detectron2
+
+ from torch import nn
+ import torch
+ import torch.nn.functional as F
+
+
+ def c2_msra_fill(module: nn.Module) -> None:
+     """
+     Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
+     Also initializes `module.bias` to 0.
+     Args:
+         module (torch.nn.Module): module to initialize.
+     """
+     nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+     if module.bias is not None:
+         # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
+         #  torch.Tensor]`.
+         nn.init.constant_(module.bias, 0)
+
+
+ def get_norm(norm, out_channels):
+     """
+     Args:
+         norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
+             or a callable that takes a channel number and returns
+             the normalization layer as a nn.Module.
+
+     Returns:
+         nn.Module or None: the normalization layer
+     """
+     if norm is None:
+         return None
+     if isinstance(norm, str):
+         if len(norm) == 0:
+             return None
+         norm = {
+             "BN": torch.nn.BatchNorm2d,
+             # Fixed in https://github.com/pytorch/pytorch/pull/36382
+             # "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
+             "FrozenBN": FrozenBatchNorm2d,
+             "GN": lambda channels: nn.GroupNorm(32, channels),
+             # for debugging:
+             "nnSyncBN": nn.SyncBatchNorm,
+             # "naiveSyncBN": NaiveSyncBatchNorm,
+             # expose stats_mode N as an option to caller, required for zero-len inputs
+             # "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
+         }[norm]
+     return norm(out_channels)
+
+
+ class Conv2d(torch.nn.Conv2d):
+     """
+     A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
+     """
+
+     def __init__(self, *args, **kwargs):
+         """
+         Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
+
+         Args:
+             norm (nn.Module, optional): a normalization layer
+             activation (callable(Tensor) -> Tensor): a callable activation function
+
+         It assumes that norm layer is used before activation.
+         """
+         norm = kwargs.pop("norm", None)
+         activation = kwargs.pop("activation", None)
+         super().__init__(*args, **kwargs)
+
+         self.norm = norm
+         self.activation = activation
+
+     def forward(self, x):
+         # torchscript does not support SyncBatchNorm yet
+         # https://github.com/pytorch/pytorch/issues/40507
+         # and we skip these codes in torchscript since:
+         # 1. currently we only support torchscript in evaluation mode
+         # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
+         #    later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
+         if not torch.jit.is_scripting():
+             if x.numel() == 0 and self.training:
+                 # https://github.com/pytorch/pytorch/issues/12013
+                 assert not isinstance(
+                     self.norm, torch.nn.SyncBatchNorm
+                 ), "SyncBatchNorm does not support empty inputs!"
+
+         x = F.conv2d(
+             x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+         )
+         if self.norm is not None:
+             x = self.norm(x)
+         if self.activation is not None:
+             x = self.activation(x)
+         return x
+
+
+ class FrozenBatchNorm2d(nn.Module):
+     """
+     BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+     It contains non-trainable buffers called
+     "weight" and "bias", "running_mean", "running_var",
+     initialized to perform identity transformation.
+
+     The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+     which are computed from the original four parameters of BN.
+     The affine transform `x * weight + bias` will perform the equivalent
+     computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+     When loading a backbone model from Caffe2, "running_mean" and "running_var"
+     will be left unchanged as identity transformation.
+
+     Other pre-trained backbone models may contain all 4 parameters.
+
+     The forward is implemented by `F.batch_norm(..., training=False)`.
+     """
+
+     _version = 3
+
+     def __init__(self, num_features, eps=1e-5):
+         super().__init__()
+         self.num_features = num_features
+         self.eps = eps
+         self.register_buffer("weight", torch.ones(num_features))
+         self.register_buffer("bias", torch.zeros(num_features))
+         self.register_buffer("running_mean", torch.zeros(num_features))
+         self.register_buffer("running_var", torch.ones(num_features) - eps)
+
+     def forward(self, x):
+         if x.requires_grad:
+             # When gradients are needed, F.batch_norm will use extra memory
+             # because its backward op computes gradients for weight/bias as well.
+             scale = self.weight * (self.running_var + self.eps).rsqrt()
+             bias = self.bias - self.running_mean * scale
+             scale = scale.reshape(1, -1, 1, 1)
+             bias = bias.reshape(1, -1, 1, 1)
+             out_dtype = x.dtype  # may be half
+             return x * scale.to(out_dtype) + bias.to(out_dtype)
+         else:
+             # When gradients are not needed, F.batch_norm is a single fused op
+             # and provide more optimization opportunities.
+             return F.batch_norm(
+                 x,
+                 self.running_mean,
+                 self.running_var,
+                 self.weight,
+                 self.bias,
+                 training=False,
+                 eps=self.eps,
+             )
+
+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         version = local_metadata.get("version", None)
+
+         if version is None or version < 2:
+             # No running_mean/var in early versions
+             # This will silent the warnings
+             if prefix + "running_mean" not in state_dict:
+                 state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+             if prefix + "running_var" not in state_dict:
+                 state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+         super()._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+         )
+
+     def __repr__(self):
+         return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+     @classmethod
+     def convert_frozen_batchnorm(cls, module):
+         """
+         Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+
+         Args:
+             module (torch.nn.Module):
+
+         Returns:
+             If module is BatchNorm/SyncBatchNorm, returns a new module.
+             Otherwise, in-place convert module and return it.
+
+         Similar to convert_sync_batchnorm in
+         https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+         """
+         bn_module = nn.modules.batchnorm
+         bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+         res = module
+         if isinstance(module, bn_module):
+             res = cls(module.num_features)
+             if module.affine:
+                 res.weight.data = module.weight.data.clone().detach()
+                 res.bias.data = module.bias.data.clone().detach()
+             res.running_mean.data = module.running_mean.data
+             res.running_var.data = module.running_var.data
+             res.eps = module.eps
+         else:
+             for name, child in module.named_children():
+                 new_child = cls.convert_frozen_batchnorm(child)
+                 if new_child is not child:
+                     res.add_module(name, new_child)
+         return res
+
+
+ class CNNBlockBase(nn.Module):
+     """
+     A CNN block is assumed to have input channels, output channels and a stride.
+     The input and output of `forward()` method must be NCHW tensors.
+     The method can perform arbitrary computation but must match the given
+     channels and stride specification.
+
+     Attribute:
+         in_channels (int):
+         out_channels (int):
+         stride (int):
+     """
+
+     def __init__(self, in_channels, out_channels, stride):
+         """
+         The `__init__` method of any subclass should also contain these arguments.
+
+         Args:
+             in_channels (int):
+             out_channels (int):
+             stride (int):
+         """
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.stride = stride
+
+     def freeze(self):
+         """
+         Make this block not trainable.
+         This method sets all parameters to `requires_grad=False`,
+         and convert all BatchNorm layers to FrozenBatchNorm
+
+         Returns:
+             the block itself
+         """
+         for p in self.parameters():
+             p.requires_grad = False
+         FrozenBatchNorm2d.convert_frozen_batchnorm(self)
+         return self
+
+
+ class BottleneckBlock(CNNBlockBase):
+     """
+     The standard bottleneck residual block used by ResNet-50, 101 and 152
+     defined in :paper:`ResNet`. It contains 3 conv layers with kernels
+     1x1, 3x3, 1x1, and a projection shortcut if needed.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         # *,
+         bottleneck_channels,
+         stride=1,
+         num_groups=1,
+         norm="BN",
+         stride_in_1x1=False,
+         dilation=1,
+     ):
+         """
+         Args:
+             bottleneck_channels (int): number of output channels for the 3x3
+                 "bottleneck" conv layers.
+             num_groups (int): number of groups for the 3x3 conv layer.
+             norm (str or callable): normalization for all conv layers.
+                 See :func:`layers.get_norm` for supported format.
+             stride_in_1x1 (bool): when stride>1, whether to put stride in the
+                 first 1x1 convolution or the bottleneck 3x3 convolution.
+             dilation (int): the dilation rate of the 3x3 conv layer.
+         """
+         super().__init__(in_channels, out_channels, stride)
+
+         if in_channels != out_channels:
+             self.shortcut = Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 stride=stride,
+                 bias=False,
+                 norm=get_norm(norm, out_channels),
+             )
+         else:
+             self.shortcut = None
+
+         # The original MSRA ResNet models have stride in the first 1x1 conv
+         # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
+         # stride in the 3x3 conv
+         stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
+
+         self.conv1 = Conv2d(
+             in_channels,
+             bottleneck_channels,
+             kernel_size=1,
+             stride=stride_1x1,
+             bias=False,
+             norm=get_norm(norm, bottleneck_channels),
+         )
+
+         self.conv2 = Conv2d(
+             bottleneck_channels,
+             bottleneck_channels,
+             kernel_size=3,
+             stride=stride_3x3,
+             padding=1 * dilation,
+             bias=False,
+             groups=num_groups,
+             dilation=dilation,
+             norm=get_norm(norm, bottleneck_channels),
+         )
+
+         self.conv3 = Conv2d(
+             bottleneck_channels,
+             out_channels,
+             kernel_size=1,
+             bias=False,
+             norm=get_norm(norm, out_channels),
+         )
+
+         for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
+             if layer is not None:  # shortcut can be None
+                 c2_msra_fill(layer)
+
+         # Zero-initialize the last normalization in each residual branch,
+         # so that at the beginning, the residual branch starts with zeros,
+         # and each residual block behaves like an identity.
+         # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
+         # "For BN layers, the learnable scaling coefficient γ is initialized
+         # to be 1, except for each residual block's last BN
+         # where γ is initialized to be 0."
+
+         # nn.init.constant_(self.conv3.norm.weight, 0)
+         # TODO this somehow hurts performance when training GN models from scratch.
+         # Add it as an option when we need to use this code to train a backbone.
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = F.relu_(out)
+
+         out = self.conv2(out)
+         out = F.relu_(out)
+
+         out = self.conv3(out)
+
+         if self.shortcut is not None:
+             shortcut = self.shortcut(x)
+         else:
+             shortcut = x
+
+         out += shortcut
+         out = F.relu_(out)
+         return out
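Note: a small shape check for BottleneckBlock as it is used in prefixLM.py (1x1 -> 3x3 -> 1x1 convolutions with a projection shortcut when the channel count changes):

    import torch

    block = BottleneckBlock(in_channels=64, out_channels=256, bottleneck_channels=64)
    x = torch.randn(1, 64, 28, 28)
    y = block(x)  # (1, 256, 28, 28): spatial size preserved, channels projected via the shortcut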
Model/AttDes/models/tokenizer.py ADDED
@@ -0,0 +1,92 @@
+ # take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
+ # to give users a quick easy start to training DALL-E without doing BPE
+
+ import torch
+
+
+ # from transformers import BertTokenizer
+
+ import html
+ import os
+ from functools import lru_cache
+ from pathlib import Path
+ import ftfy
+ import regex as re
+
+ # OpenAI simple tokenizer
+
+ @lru_cache()
+ def default_bpe():
+     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+     cs = bs[:]
+     n = 0
+     for b in range(2 ** 8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2 ** 8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r'\s+', ' ', text)
+     text = text.strip()
+     return text
+
+
+ # chinese tokenizer
+ class ChineseTokenizer:
+     def __init__(self):
+         tokenizer = torch.load('./models/Chinese_tokenizer.pth')  # BertTokenizer.from_pretrained('bert-base-chinese')
+         self.tokenizer = tokenizer
+         self.vocab_size = tokenizer.vocab_size + 2
+
+     def decode(self, tokens):
+         if torch.is_tensor(tokens):
+             tokens = tokens.tolist()
+
+         tokens = [token for token in tokens if token not in (0,)]
+         return self.tokenizer.decode(tokens)
+
+     def encode(self, text, train=False):
+         t = torch.tensor(self.tokenizer.encode(text, add_special_tokens=False))
+         if train:
+             return torch.cat([t, torch.tensor([5])], dim=-1)
+         else:
+             return t
+         # special token: [CLS]==4, [SEP]==5, [PAD]==0, <bos>=7
+
+     def tokenize(self, texts, context_length=77, truncate_text=False, train=True):
+         if isinstance(texts, str):
+             texts = [texts]
+
+         all_tokens = [self.encode(text, train=train) for text in texts]
+
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+         for i, tokens in enumerate(all_tokens):
+             if len(tokens) > context_length:
+                 if truncate_text:
+                     tokens = tokens[:context_length]
+                 else:
+                     raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+             result[i, :len(tokens)] = torch.tensor(tokens)
+
+         return result
+
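Note: a usage sketch for ChineseTokenizer; the constructor loads './models/Chinese_tokenizer.pth' (the LFS-tracked file in this commit) relative to the working directory:

    tok = ChineseTokenizer()                     # expects ./models/Chinese_tokenizer.pth to exist
    batch = tok.tokenize(["一张编织餐椅", "白色纱幔"], context_length=77, train=True)
    print(batch.shape)                           # torch.Size([2, 77]), zero-padded ([PAD] == 0)
    print(tok.decode(batch[0]))                  # padding ids are dropped before decoding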
Model/AttDes/models/transformer.py ADDED
@@ -0,0 +1,291 @@
+ import torch
+ import copy
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import math
+
+
+ # helpers
+ def clone(module, N):
+     '''copy the given module N times'''
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+ def subsequent_mask(size):
+     attn_shape = (1, size, size)
+     subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype(bool)
+     return torch.from_numpy(subsequent_mask) == False
+
+
+ def attention(query, key, value, mask=None, dropout=None):
+     d_k = query.size(-1)
+     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+     if mask is not None:
+         scores = scores.masked_fill(mask == 0, -1e9)
+     p_attn = F.softmax(scores, dim=-1)
+     if dropout is not None:
+         p_attn = dropout(p_attn)
+     return torch.matmul(p_attn, value), p_attn
+
+
+ class MultiHeadedAttention(nn.Module):
+     def __init__(self, h, d_model, dropout=0.1):
+         super(MultiHeadedAttention, self).__init__()
+         assert d_model % h == 0
+         self.d_k = d_model // h
+         self.h = h
+         self.linears = clone(nn.Linear(d_model, d_model), 4)
+         self.attn = None
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, query, key, value, mask=None):
+         if mask is not None:
+             mask = mask.unsqueeze(1)
+         '''print('q:',query)
+         print('k:',key)
+         print('v:',value)'''
+         nbatchs = query.size(0)
+         query, key, value = [l(x).view(nbatchs, -1, self.h, self.d_k).transpose(1, 2)
+                              for l, x in zip(self.linears, (query, key, value))]
+         x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
+         x = x.transpose(1, 2).contiguous().view(nbatchs, -1, self.h * self.d_k)
+         return self.linears[-1](x)
+
+
+ class Feedforward(nn.Module):
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super(Feedforward, self).__init__()
+         self.w_1 = nn.Linear(d_model, d_ff)
+         self.w_2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.w_2(self.dropout(F.relu(self.w_1(x))))
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, features, eps=1e-6):
+         super(LayerNorm, self).__init__()
+         self.a_2 = nn.Parameter(torch.ones(features))
+         self.b_2 = nn.Parameter(torch.zeros(features))
+         self.eps = eps
+
+     def forward(self, x):
+         mean = x.mean(-1, keepdim=True)
+         std = x.std(-1, keepdim=True)
+         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
+
+
+ class Generator(nn.Module):
+     def __init__(self, d_model, vocab):
+         super(Generator, self).__init__()
+         self.proj = nn.Linear(d_model, vocab)
+
+     def forward(self, x):
+         return F.log_softmax(self.proj(x), dim=-1)
+
+
+ # the EncoderLayer is cloned enc_depth times
+ class Encoder(nn.Module):
+     def __init__(self, layer, N):
+         '''N encoder layers'''
+         super(Encoder, self).__init__()
+         self.layers = clone(layer, N)
+         self.norm = LayerNorm(layer.size)
+
+     def forward(self, x, mask=None):
+         for layer in self.layers:
+             x = layer(x, mask)
+         return self.norm(x)
+
+
+ class SublayerConnection(nn.Module):
+     '''LayerNorm + sublayer + dropout + residual connection'''
+     def __init__(self, size, dropout):
+         super(SublayerConnection, self).__init__()
+         self.norm = LayerNorm(size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, sublayer):
+         return x + self.dropout(sublayer(self.norm(x)))
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, size, self_attn, feed_forward, dropout):
+         '''size is the embedding dimension'''
+         super(EncoderLayer, self).__init__()
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+         self.sublayer = clone(SublayerConnection(size, dropout), 2)
+         self.size = size
+
+     def forward(self, x, mask=None):
+         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
+         return self.sublayer[1](x, self.feed_forward)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, layer, N):
+         super(Decoder, self).__init__()
+         self.layers = clone(layer, N)
+         self.norm = LayerNorm(layer.size)
+
+     def forward(self, x, memory, src_mask=None, tgt_mask=None):
+         for layer in self.layers:
+             x = layer(x, memory, src_mask, tgt_mask)
+         return self.norm(x)
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
+         super(DecoderLayer, self).__init__()
+         self.size = size
+         self.self_attn = self_attn
+         self.src_attn = src_attn
+         self.feed_forward = feed_forward
+         self.sublayer = clone(SublayerConnection(size, dropout), 3)
+
+     def forward(self, x, memory, src_mask=None, tgt_mask=None):
+         m = memory
+         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
+         x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
+         return self.sublayer[2](x, self.feed_forward)
+
+
+ class CrossAttLayer(nn.Module):
+     def __init__(self, d_model, self_attn, feed_forward, dropout=0.1):
+         super(CrossAttLayer, self).__init__()
+         self.size = d_model
+         self.self_attn = self_attn
+         # self.self_attn_0 = copy.deepcopy(self_attn)
+         self.feed_forward = feed_forward
+         self.dropout = nn.Dropout(dropout)
+         self.sublayer = clone(SublayerConnection(d_model, dropout), 2)
+         # self.sublayer = clone(SublayerConnection(d_model, dropout), 3)  # could be three sublayers, with self_attn as the first
+
+     def forward(self, q, k, v, src_mask=None):
+         # k = self.sublayer[0](k, lambda k: self.self_attn_0(k, k, k))
+         # q = self.sublayer[0](q, lambda q: self.self_attn_0(q, q, q))
+         # x = self.sublayer[1](q, lambda q: self.self_attn(q, k, k, src_mask))
+         # x = self.sublayer[2](x, self.feed_forward)
+         x = self.sublayer[0](q, lambda q: self.self_attn(q, k, k, src_mask))
+         x = self.sublayer[1](x, self.feed_forward)
+         return x
+
+
+ class CrossAtt(nn.Module):
+     def __init__(self, crossAttlayer, N=1):
+         super(CrossAtt, self).__init__()
+         self.layers = clone(crossAttlayer, N)
+         self.norm = LayerNorm(crossAttlayer.size)
+
+     def forward(self, q, k, v, src_mask=None):
+         for crossAttnLayer in self.layers:
+             q = crossAttnLayer(q, k, v, src_mask)
+         return self.norm(q)
+
+
+ class Transformer(nn.Module):
+     def __init__(self, d_model=512, heads=8, enc_depth=8, dec_depth=8, d_ff=1024, dropout=0.1):
+         super(Transformer, self).__init__()
+         c = copy.deepcopy
+         attn = MultiHeadedAttention(heads, d_model)
+         ff = Feedforward(d_model, d_ff, dropout)
+         self.encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), enc_depth)
+         self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), dec_depth)
+         # self.register_buffer('src_mask', src_mask, persistent=False)
+         # self.register_buffer('tgt_mask', tgt_mask, persistent=False)
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src_embeded, tgt_embeded, src_mask=None, tgt_mask=None):
+         return self.decode(self.encode(src_embeded, src_mask), tgt_embeded, src_mask, tgt_mask)
+
+     def encode(self, src_embeded, src_mask=None):
+         return self.encoder(src_embeded, src_mask)
+
+     def decode(self, memory, tgt_embeded, src_mask=None, tgt_mask=None):
+         return self.decoder(tgt_embeded, memory, src_mask, tgt_mask)
+
+
+ class ModelOne(nn.Module):
+     def __init__(self, d_model=512, heads=8, enc_depth=8, dec_depth=8, d_ff=1024, dropout=0.1):
+         super(ModelOne, self).__init__()
+         c = copy.deepcopy
+         attn = MultiHeadedAttention(heads, d_model)
+         ff = Feedforward(d_model, d_ff, dropout)
+         self.CrossAtt = CrossAtt(CrossAttLayer(d_model, c(attn), c(ff), dropout), N=1)
+         self.encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), enc_depth)
+         self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), dec_depth)
+         # self.register_buffer('src_mask', src_mask, persistent=False)
+         # self.register_buffer('tgt_mask', tgt_mask, persistent=False)
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
+         # x = self.CrossAtt(q, img_embed, img_embed)
+         # x2 = self.CrossAtt(q, des_embed, des_embed)
+         des_embed_self = self.CrossAtt(des_embed, des_embed, des_embed)
+         x3 = self.CrossAtt(img_embed, des_embed_self, des_embed_self)
+         # src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
+         src_embeded = torch.cat((x3, obj_embed), dim=1)
+         x = self.encode(src_embeded, src_mask)
+         x = self.decode(x, tgt_embeded, src_mask, tgt_mask)
+         return x
+
+     def encode(self, src_embeded, src_mask=None):
+         return self.encoder(src_embeded, src_mask)
+
+     def decode(self, memory, tgt_embeded, src_mask=None, tgt_mask=None):
+         return self.decoder(tgt_embeded, memory, src_mask, tgt_mask)
+
+
+ class Model005(nn.Module):
+     def __init__(self, d_model=512, heads=8, enc_depth=8, dec_depth=8, d_ff=1024, dropout=0.1):
+         super(Model005, self).__init__()
+         c = copy.deepcopy
+         attn = MultiHeadedAttention(heads, d_model)
+         ff = Feedforward(d_model, d_ff, dropout)
+         self.CrossAtt = CrossAtt(CrossAttLayer(d_model, c(attn), c(ff), dropout), N=1)
+         self.encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), enc_depth)
+         self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), dec_depth)
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
+         x = self.CrossAtt(q, img_embed, img_embed)
+         src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
+         x = self.encode(src_embeded, src_mask)
+         x = self.decode(x, tgt_embeded, src_mask, tgt_mask)
+         return x
+
+     def encode(self, src_embeded, src_mask=None):
+         return self.encoder(src_embeded, src_mask)
+
+     def decode(self, memory, tgt_embeded, src_mask=None, tgt_mask=None):
+         return self.decoder(tgt_embeded, memory, src_mask, tgt_mask)
+
+
+ class Model006(nn.Module):
+     def __init__(self, d_model=512, heads=8, enc_depth=8, dec_depth=8, d_ff=1024, dropout=0.1):
+         super(Model006, self).__init__()
+         c = copy.deepcopy
+         attn = MultiHeadedAttention(heads, d_model)
+         ff = Feedforward(d_model, d_ff, dropout)
+         self.CrossAtt = CrossAtt(CrossAttLayer(d_model, c(attn), c(ff), dropout), N=1)
+         self.encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), enc_depth)
+         self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), dec_depth)
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
+         x = self.CrossAtt(img_embed, img_embed, img_embed)
+         x = self.CrossAtt(obj_embed, x, x)
+         src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
+         x = self.encode(src_embeded, src_mask)
+         x = self.decode(x, tgt_embeded, src_mask, tgt_mask)
+         return x
+
+     def encode(self, src_embeded, src_mask=None):
+         return self.encoder(src_embeded, src_mask)
+
+     def decode(self, memory, tgt_embeded, src_mask=None, tgt_mask=None):
+         return self.decoder(tgt_embeded, memory, src_mask, tgt_mask)
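Note: a quick shape sketch for the plain Transformer and subsequent_mask (already-embedded inputs, as produced by PrefixLM):

    import torch

    model = Transformer(d_model=512, heads=8, enc_depth=2, dec_depth=2, d_ff=1024)
    src = torch.randn(2, 784 + 256 + 8, 512)   # image + description + object tokens, already embedded
    tgt = torch.randn(2, 35, 512)
    tgt_mask = subsequent_mask(35)             # (1, 35, 35) causal (lower-triangular) boolean mask
    out = model(src, tgt, tgt_mask=tgt_mask)   # (2, 35, 512); map to vocabulary logits afterwards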
Model/AttDes/models/visual_model/Chinese_tokenizer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2403030d0e018aedffec4a62d69b124c350a5b1ef03035395dbcb3593deca8dd
+ size 142959
Model/AttDes/models/visual_model/backbone.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Backbone modules.
+ """
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from torch import nn
+ from torchvision.models._utils import IntermediateLayerGetter
+ from typing import Dict, List
+
+ from utils.misc import NestedTensor, is_main_process
+
+ from .position_encoding import build_position_encoding
+
+
+ class FrozenBatchNorm2d(torch.nn.Module):
+     """
+     BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+     Copy-paste from torchvision.misc.ops with added eps before rqsrt,
+     without which any other models than torchvision.models.resnet[18,34,50,101]
+     produce nans.
+     """
+
+     def __init__(self, n):
+         super(FrozenBatchNorm2d, self).__init__()
+         self.register_buffer("weight", torch.ones(n))
+         self.register_buffer("bias", torch.zeros(n))
+         self.register_buffer("running_mean", torch.zeros(n))
+         self.register_buffer("running_var", torch.ones(n))
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                               missing_keys, unexpected_keys, error_msgs):
+         num_batches_tracked_key = prefix + 'num_batches_tracked'
+         if num_batches_tracked_key in state_dict:
+             del state_dict[num_batches_tracked_key]
+
+         super(FrozenBatchNorm2d, self)._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict,
+             missing_keys, unexpected_keys, error_msgs)
+
+     def forward(self, x):
+         # move reshapes to the beginning
+         # to make it fuser-friendly
+         w = self.weight.reshape(1, -1, 1, 1)
+         b = self.bias.reshape(1, -1, 1, 1)
+         rv = self.running_var.reshape(1, -1, 1, 1)
+         rm = self.running_mean.reshape(1, -1, 1, 1)
+         eps = 1e-5
+         scale = w * (rv + eps).rsqrt()
+         bias = b - rm * scale
+         return x * scale + bias
+
+
+ class BackboneBase(nn.Module):
+
+     def __init__(self, name: str, backbone: nn.Module, num_channels: int, return_interm_layers: bool):
+         super().__init__()
+         for name, parameter in backbone.named_parameters():
+             if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                 parameter.requires_grad_(False)
+         if return_interm_layers:
+             return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+         else:
+             return_layers = {'layer4': "0"}
+         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+         self.num_channels = num_channels
+
+     def forward(self, tensor_list: NestedTensor):
+         xs = self.body(tensor_list.tensors)
+         out: Dict[str, NestedTensor] = {}
+         for name, x in xs.items():
+             m = tensor_list.mask
+             assert m is not None
+             mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+             out[name] = NestedTensor(x, mask)
+         return out
+
+
+ class Backbone(BackboneBase):
+     """ResNet backbone with frozen BatchNorm."""
+     def __init__(self, name: str,
+                  return_interm_layers: bool,
+                  dilation: bool):
+
+         backbone = getattr(torchvision.models, name)(
+             replace_stride_with_dilation=[False, False, dilation],
+             pretrained=False, norm_layer=FrozenBatchNorm2d)
+             # pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
+         assert name in ('resnet50', 'resnet101')
+         num_channels = 2048
+         super().__init__(name, backbone, num_channels, return_interm_layers)
+
+
+ class Joiner(nn.Sequential):
+     def __init__(self, backbone, position_embedding):
+         super().__init__(backbone, position_embedding)
+
+     def forward(self, tensor_list: NestedTensor):
+         xs = self[0](tensor_list)
+         out: List[NestedTensor] = []
+         pos = []
+         for name, x in xs.items():
+             out.append(x)
+             # position encoding
+             pos.append(self[1](x).to(x.tensors.dtype))
+
+         return out, pos
+
+
+ def build_backbone(args):
+     position_embedding = build_position_encoding(args)
+     # train_backbone = args.lr_detr > 0
+     return_interm_layers = False
+     backbone = Backbone(args.backbone, return_interm_layers, args.dilation)
+     model = Joiner(backbone, position_embedding)
+     model.num_channels = backbone.num_channels
+     return model
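Note: a hedged sketch of build_backbone; the argument names (backbone, dilation, hidden_dim, position_embedding) are assumed to come from the training/validation arg parser, which this commit only partially includes:

    from argparse import Namespace

    args = Namespace(backbone='resnet50', dilation=False, hidden_dim=256, position_embedding='sine')
    model = build_backbone(args)   # Joiner(ResNet-50 with FrozenBatchNorm2d, sine position encoding)
    print(model.num_channels)      # 2048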
Model/AttDes/models/visual_model/position_encoding.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the visual model.
+ """
+ import math
+ import torch
+ from torch import nn
+
+ from utils.misc import NestedTensor
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         mask = tensor_list.mask
+         assert mask is not None
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return pos
+
+
+ def build_position_encoding(args):
+     N_steps = args.hidden_dim // 2
+     if args.position_embedding in ('v2', 'sine'):
+         # TODO find a better way of exposing other arguments
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif args.position_embedding in ('v3', 'learned'):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {args.position_embedding}")
+
+     return position_embedding
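Note: a minimal sketch of the sine position encoding on a ResNet feature map; NestedTensor is assumed from utils.misc (not part of this commit):

    import torch
    from argparse import Namespace
    from utils.misc import NestedTensor  # assumed helper

    pe = build_position_encoding(Namespace(hidden_dim=256, position_embedding='sine'))
    feat = NestedTensor(torch.randn(1, 2048, 14, 14), torch.zeros(1, 14, 14, dtype=torch.bool))
    pos = pe(feat)   # (1, 256, 14, 14): 128 sine/cosine dims for y concatenated with 128 for x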
Model/AttDes/validate_local.py ADDED
@@ -0,0 +1,399 @@
+ import argparse
+ import datetime
+ import json
+ import random
+ import time
+ import math
+ import os
+ import pandas as pd
+ import numpy as np
+ from pathlib import Path
+ import torch
+ from nltk.translate import bleu_score
+ import sys
+ sys.path.append(r"E:\data\streamlit\Model\AttDes")
+ sys.path.append(r"E:\data\streamlit\Model\CLIP")
+ from AttDes import dataset
+ from AttDes.dataset import data_loader
+ from torch.utils.data import DataLoader, DistributedSampler
+ import torchvision.transforms as transforms
+ import AttDes.models as models
+ from AttDes.models import prefixLM, tokenizer
+
+ import nltk
+ import jieba
+ # from engine import train_one_epoch, validate
+ #
+ # import utils.misc as utils
+ # from models import __init__
+ # from dataset import build_dataset
+ # from engine import train_one_epoch, validate_txt
+
+ from einops import rearrange
+ from pytorch_pretrained_bert.tokenization import BertTokenizer
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser('Set parser', add_help=False)
+     parser.add_argument('--device', default='cuda')
+     # parser.add_argument('--gpu_id', default='0', type=str)
+
+     # Dataset parameters
+     parser.add_argument('--data_root', type=str, default='/hy-nas/zhanghe/data/fur/txt/data_for_test2.csv')
+     parser.add_argument('--dataset_name', type=str, default='Furniture')
+     parser.add_argument('--img_root', type=str, default='/hy-nas/zhanghe/data/fur/processed_img')
+     parser.add_argument('--output_dir', default='./outputs/validate', help='path where to save, empty for no saving')
+     parser.add_argument('--seed', default=2022, type=int)
+     parser.add_argument('--resume', default='', help='resume for checkpoint')
+     parser.add_argument('--bert_model', default='bert-base-chinese', type=str)
+     parser.add_argument('--des_len', default=256, type=int)
+     parser.add_argument('--obj_len', default=8, type=int)
+     parser.add_argument('--tgt_len', default=35, type=int)
+
+     # Train parameters
+     parser.add_argument('--lr', default=1e-4, type=float)
+     parser.add_argument('--batch_size', default=1, type=int)
+     parser.add_argument('--weight_decay', default=1e-4, type=float)
+     parser.add_argument('--optimizer', default='adamw', type=str)
+     parser.add_argument('--lr_scheduler', default='step', type=str)
+     parser.add_argument('--lr_drop', default=5, type=int)
+     parser.add_argument('--start_epoch', default=0, type=int)
+     parser.add_argument('--epochs', default=1, type=int)
+
+     # Model parameters
+     parser.add_argument('--AD_hidden_dim', default=256, type=int)
+     parser.add_argument('--d_model', default=512, type=int)
+     # visual_model parameters
+     parser.add_argument('--backbone', default='resnet50', type=str,
+                         help="Name of the convolutional backbone to use")
+
+     return parser
+
+
+ def main(args):
+     device = torch.device(args.device)
+
+     # seed = args.seed
+     # torch.manual_seed(seed)
+     # np.random.seed(seed)
+     # random.seed(seed)
+     normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
+                                      std=[0.1673, 0.1695, 0.1705])
+     the_transforms = transforms.Compose([transforms.Resize((448, 448)),
+                                          transforms.RandomHorizontalFlip(),
+                                          transforms.ToTensor(),
+                                          normalize,
+                                          ])
+     dataset_all = data_loader.AttDesDataset(args.data_root, args.dataset_name,
+                                             des_len=args.des_len,
+                                             obj_len=args.obj_len,
+                                             tgt_len=args.tgt_len,
+                                             img_root=args.img_root,
+                                             transform=the_transforms)
+
+     dataloader_val = DataLoader(dataset_all,
+                                 batch_size=args.batch_size,
+                                 shuffle=False)
+     print("data loaded...")
+
+     Tokenizer = tokenizer.ChineseTokenizer()
+     PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
+                               input_resolution=448,
+                               patch_size=16,
+                               num_text_tokens=20000,
+                               txt_seq_len=10000,
+                               heads=4,
+                               enc_depth=8,
+                               dec_depth=8,
+                               d_ff=1024,
+                               dropout=0.1)
+     model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
+     model.load_state_dict(torch.load('./outputs/005/checkpoint0019.pth'))
+
+     output_dir = Path(args.output_dir)
+     with (output_dir / "log.txt").open("a") as f:
+         f.write(str(args) + "\n")
+
+     print("start validate...")
+     start_time = time.time()
+     # optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+     # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2000)
+     # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
+     for epoch in range(args.start_epoch, args.epochs):
+         validate_txt(args, model, dataloader_val, device, batch_size=args.batch_size)
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Validate time {}'.format(total_time_str))
+
+
+ def load_AttDes_Model(model_path, device):
+     parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
+     args = parser.parse_args()
+     normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
+                                      std=[0.1673, 0.1695, 0.1705])
+     the_transforms = transforms.Compose([transforms.Resize((448, 448)),
+                                          transforms.RandomHorizontalFlip(),
+                                          transforms.ToTensor(),
+                                          normalize,
+                                          ])
+     dataset_all = data_loader.AttDesDataset(args.data_root, args.dataset_name,
+                                             des_len=args.des_len,
+                                             obj_len=args.obj_len,
+                                             tgt_len=args.tgt_len,
+                                             img_root=args.img_root,
+                                             transform=the_transforms)
+     PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
+                               input_resolution=448,
+                               patch_size=16,
+                               num_text_tokens=20000,
+                               txt_seq_len=10000,
+                               heads=4,
+                               enc_depth=8,
+                               dec_depth=8,
+                               d_ff=1024,
+                               dropout=0.1)
+     time_1 = time.time()
+     model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
+     model.load_state_dict(torch.load(model_path))
+     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
+     time_2 = time.time()
+     print('Load model takes {}s'.format(time_2 - time_1))
+     return model, dataset_all, tokenizer
+
+
+ def validate(img1_id, img2_id, obj, model_path):
+     parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
+     args = parser.parse_args()
+     device = torch.device(args.device)
170
+ #
171
+ # seed = args.seed
172
+ # torch.manual_seed(seed)
173
+ # np.random.seed(seed)
174
+ # random.seed(seed)
175
+ normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
176
+ std=[0.1673, 0.1695, 0.1705])
177
+
178
+ the_transforms = transforms.Compose([transforms.Resize((448, 448)),
179
+ transforms.RandomHorizontalFlip(),
180
+ transforms.ToTensor(),
181
+ normalize,
182
+ ])
183
+ dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
184
+ des_len=args.des_len,
185
+ obj_len=args.obj_len,
186
+ tgt_len=args.tgt_len,
187
+ img_root=args.img_root,
188
+ transform=the_transforms)
189
+ PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
190
+ input_resolution=448,
191
+ patch_size=16,
192
+ num_text_tokens=20000,
193
+ txt_seq_len=10000,
194
+ heads=4,
195
+ enc_depth=8,
196
+ dec_depth=8,
197
+ d_ff=1024,
198
+ dropout=0.1)
199
+ time_1 = time.time()
200
+ model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
201
+ model.load_state_dict(torch.load(model_path))
202
+ tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
203
+ time_2 = time.time()
204
+ print('Load model takes {}s'.format(time_2 - time_1))
205
+ out_list = []
206
+
207
+ label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img1_id, obj, device, tokenizer)
208
+ out_list.append([label_txt, output1, output2, output3])
209
+ label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img2_id, obj, device, tokenizer)
210
+ out_list.append([label_txt, output1, output2, output3])
211
+ return out_list
212
+
213
+ def get_data_from_csv_by_id(path, id):
214
+ data_csv = pd.read_csv(path, encoding='utf-8')
215
+ # print(data_csv)
216
+ pic_id_list = data_csv['pic_id'].values
217
+ des_list = data_csv['des'].values
218
+
219
+ for i in range(len(pic_id_list)):
220
+ if str(pic_id_list[i]) == str(id):
221
+ return des_list[i]
222
+
223
+ return ""
224
+
225
+ def validate_one_img(model, dataset_all, img_ids, obj_given, device, tokenizer):
226
+ batch_size = len(img_ids)
227
+ start_time = time.time()
228
+ model.eval()
229
+ imgs = []
230
+ dess = []
231
+ objs = []
232
+ for i in range(len(img_ids)):
233
+ img, des, obj = dataset_all.get_all_from_id(img_ids[i], obj_given[i])
234
+ # print("get img from id time:", time.time() - start_time) # 3s
235
+ imgs.append(img)
236
+ dess.append(des)
237
+ objs.append(obj)
238
+ img_data = torch.stack(imgs).to(device)
239
+ des_data = torch.stack(dess).to(device)
240
+ obj_data = torch.stack(objs).to(device)
241
+ # print("get batch time:", time.time() - start_time) # 3s
242
+ img_emed = model.ResNet(img_data)
243
+ img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
244
+ img_emed += model.img_pos_embed(img_emed)
245
+
246
+ des_embed = model.txt_embed(des_data)
247
+ des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
248
+ obj_embed = model.txt_embed(obj_data)
249
+ obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))
250
+
251
+
252
+ tgt_txt = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + 101
253
+ tgt_txt_embed = model.txt_embed(tgt_txt)
254
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)
255
+
256
+ # M_005
257
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
258
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
259
+ tgt_mask=None)
260
+ logits = model.to_logits(out)[:, -1]
261
+ _, index = logits.topk(3, dim=-1)
262
+ # value: tensor([[7.3227, 7.2289, 6.4169],
263
+ # [9.6868, 7.0598, 6.3911]], device='cuda:0', grad_fn= < TopkBackward0 >)
264
+ # index: tensor([[4677, 2199, 2647],
265
+ # [4510, 3763, 2145]], device='cuda:0')
266
+ sample_1st = index[:,0]
267
+ sample_2nd = index[:,1]
268
+ sample_3rd = index[:,2]
269
+ tgt_txt0 = tgt_txt
270
+ output_list = []
271
+ # print("get 1,2,3 sample time:", time.time() - start_time) # 0.01s
272
+ for sample in [sample_1st, sample_2nd, sample_3rd]:
273
+ tgt_txt = tgt_txt0
274
+ cur_len = 1
275
+ while (cur_len < model.tgt_len and sample.max() != 102): # 102 is the id of [SEP]
276
+ tgt_txt = torch.cat((tgt_txt, sample.unsqueeze(1)), dim=-1)
277
+ tgt_txt_embed = model.txt_embed(tgt_txt)
278
+ cur_len += 1
279
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
280
+ # out = model.transformer(prefix, tgt_txt_embed)
281
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
282
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
283
+ tgt_mask=None)
284
+ logits = model.to_logits(out)[:, -1]
285
+ sample = torch.argmax(logits, dim=-1)
286
+ # print("one batch sentence token time:", time.time() - start_time) # 0.6s
287
+ output_1 = []
288
+ for i in range(batch_size):
289
+ output_txt = []
290
+ for token in tgt_txt[i].tolist():
291
+ if token > 103:
292
+ output_txt.append(token)
293
+ output_txt = tokenizer.convert_ids_to_tokens(output_txt)
294
+ output_txt = ''.join(output_txt)
295
+ output_1.append(output_txt[1:])
296
+ output_list.append(output_1)
297
+ total_time = time.time() - start_time
298
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
299
+ print('Validate time {}'.format(total_time_str))
300
+ # print(output_list)
301
+ return output_list
302
+
303
+
304
+ def generate_texts(img_id, obj, model_path):
305
+ parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
306
+ args = parser.parse_args()
307
+ device = torch.device(args.device)
308
+ # seed = args.seed
309
+ # torch.manual_seed(seed)
310
+ # np.random.seed(seed)
311
+ # random.seed(seed)
312
+ normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
313
+ std=[0.1673, 0.1695, 0.1705])
314
+ the_transforms = transforms.Compose([transforms.Resize((448, 448)),
315
+ transforms.RandomHorizontalFlip(),
316
+ transforms.ToTensor(),
317
+ normalize,
318
+ ])
319
+ dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
320
+ des_len=args.des_len,
321
+ obj_len=args.obj_len,
322
+ tgt_len=args.tgt_len,
323
+ img_root=args.img_root,
324
+ transform=the_transforms)
325
+ PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
326
+ input_resolution=448,
327
+ patch_size=16,
328
+ num_text_tokens=20000,
329
+ txt_seq_len=10000,
330
+ heads=4,
331
+ enc_depth=8,
332
+ dec_depth=8,
333
+ d_ff=1024,
334
+ dropout=0.1)
335
+ time_1 = time.time()
336
+ model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
337
+ model.load_state_dict(torch.load(model_path))
338
+ tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
339
+ time_2 = time.time()
340
+ print('Load model takes {}s'.format(time_2 - time_1))
341
+
342
+ print("start generate_texts")
343
+ start_time = time.time()
344
+ end1_time = time.time()
345
+ model.eval()
346
+ img_data, des, obj_data, target, img_id, obj_given = dataset_all.get_all_from_id(img_id, obj)
347
+
348
+ img_data = img_data.unsqueeze(0).to(device)
349
+ des = des.unsqueeze(0).to(device)
350
+ obj_given = obj_given.unsqueeze(0).to(device)
351
+ label = target.unsqueeze(0).to(device)
352
+
353
+ img_emed = model.ResNet(img_data)
354
+
355
+ img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
356
+ img_emed += model.img_pos_embed(img_emed)
357
+
358
+ des_embed = model.txt_embed(des)
359
+ des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
360
+ obj_embed = model.txt_embed(obj_given)
361
+ obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))
362
+ tgt_txt = torch.zeros(1, 1, dtype=torch.long, device=device) + 101
363
+ tgt_txt_embed = model.txt_embed(tgt_txt)
364
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)
365
+
366
+ # M_005
367
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
368
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
369
+ tgt_mask=None)
370
+
371
+
372
+
373
+
374
+ if __name__ == '__main__':
375
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
376
+ # parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
377
+ # args = parser.parse_args()
378
+ # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
379
+ # if args.output_dir:
380
+ # Path(args.output_dir).mkdir(parents=True, exist_ok=True)
381
+ # main(args)
382
+ model_name = '005'
383
+ model_path = r'E:\data\Download\models\attribute_desciption\outputs' + '/' + model_name + '/' + 'checkpoint0019.pth'
384
+ obj = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
385
+ "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]
386
+
387
+ out = generate_texts('550695', obj, model_path)
388
+
389
+
390
+
391
+
392
+
393
+
394
+
395
+
396
+
397
+
398
+
399
+
Model/AttDes/validate_local_gennerate.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import random
5
+ import time
6
+ import math
7
+ import os
8
+
9
+ import numpy as np
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from nltk.translate import bleu_score
14
+
15
+ import dataset.data_loader
16
+ import torch.backends.cudnn as cudnn
17
+ from torch.utils.data import DataLoader, DistributedSampler
18
+ import torchvision.transforms as transforms
19
+ from models import prefixLM, tokenizer
20
+ import nltk
21
+ import jieba
22
+ # from engine import train_one_epoch, validate
23
+ #
24
+ # import utils.misc as utils
25
+ from models import __init__
26
+ # from dataset import build_dataset
27
+ # from engine import train_one_epoch, validate_txt
28
+
29
+ from einops import rearrange
30
+ from pytorch_pretrained_bert.tokenization import BertTokenizer
31
+
32
+ def get_args_parser():
33
+ parser = argparse.ArgumentParser('Set parser', add_help=False)
34
+ parser.add_argument('--device', default='cuda')
35
+ parser.add_argument('--gpu_id', default='0', type=str)
36
+
37
+ # Dataset parameters
38
+ parser.add_argument('--data_root', type=str, default=r'E:\data\Download\fur\dataset\data_for_test2.csv')
39
+ parser.add_argument('--dataset_name', type=str, default='Furniture')
40
+ parser.add_argument('--img_root', type=str, default=r'E:\data\pictures')
41
+ parser.add_argument('--output_dir', default='./outputs/validate', help='path where to save, empty for no saving')
42
+ parser.add_argument('--seed', default=2022, type=int)
43
+ parser.add_argument('--resume', default='', help='resume for checkpoint')
44
+ parser.add_argument('--bert_model', default='bert-base-chinese', type=str)
45
+ parser.add_argument('--des_len', default=256, type=int)
46
+ parser.add_argument('--obj_len', default=8, type=int)
47
+ parser.add_argument('--tgt_len', default=35, type=int)
48
+
49
+
50
+ # Train parameters
51
+ parser.add_argument('--lr', default=1e-4, type=float)
52
+ parser.add_argument('--batch_size', default=1, type=int)
53
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
54
+ parser.add_argument('--optimizer', default='adamw', type=str)
55
+ parser.add_argument('--lr_scheduler', default='step', type=str)
56
+ parser.add_argument('--lr_drop', default=5, type=int)
57
+ parser.add_argument('--start_epoch', default=0, type=int)
58
+ parser.add_argument('--epochs', default=1, type=int)
59
+
60
+ # Model parameters
61
+ parser.add_argument('--AD_hidden_dim', default=256, type=int)
62
+ parser.add_argument('--d_model', default=512, type=int)
63
+ # visual_model parameters
64
+ parser.add_argument('--backbone', default='resnet50', type=str,
65
+ help="Name of the convolutional backbone to use")
66
+
67
+ return parser
68
+
69
+
70
+ def main(args):
71
+ device = torch.device(args.device)
72
+
73
+ seed = args.seed
74
+ torch.manual_seed(seed)
75
+ np.random.seed(seed)
76
+ random.seed(seed)
77
+ normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
78
+ std=[0.1673, 0.1695, 0.1705])
79
+ the_transforms = transforms.Compose([transforms.Resize((448, 448)),
80
+ transforms.RandomHorizontalFlip(),
81
+ transforms.ToTensor(),
82
+ normalize,
83
+ ])
84
+ dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
85
+ des_len=args.des_len,
86
+ obj_len=args.obj_len,
87
+ tgt_len=args.tgt_len,
88
+ img_root=args.img_root,
89
+ transform=the_transforms)
90
+
91
+ dataloader_val = DataLoader(dataset_all,
92
+ batch_size=args.batch_size,
93
+ shuffle=False)
94
+ print("data loaded...")
95
+
96
+ Tokenizer = tokenizer.ChineseTokenizer()
97
+ PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
98
+ input_resolution=448,
99
+ patch_size=16,
100
+ num_text_tokens=20000,
101
+ txt_seq_len=10000,
102
+ heads=4,
103
+ enc_depth=8,
104
+ dec_depth=8,
105
+ d_ff=1024,
106
+ dropout=0.1)
107
+ model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
108
+ model.load_state_dict(torch.load('./outputs/005/checkpoint0019.pth'))
109
+
110
+ output_dir = Path(args.output_dir)
111
+ with (output_dir / "log.txt").open("a") as f:
112
+ f.write(str(args) + "\n")
113
+
114
+ print("start validate...")
115
+ start_time = time.time()
116
+ # optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
117
+ # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2000)
118
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
119
+ for epoch in range(args.start_epoch, args.epochs):
120
+ validate_txt(args, model, dataloader_val, device, batch_size=args.batch_size)
121
+
122
+ total_time = time.time() - start_time
123
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
124
+ print('Validate time {}'.format(total_time_str))
125
+
126
+
127
+ def validate(img1_id, img2_id, obj, model_path):
128
+ parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
129
+ args = parser.parse_args()
130
+ device = torch.device(args.device)
131
+
132
+ seed = args.seed
133
+ torch.manual_seed(seed)
134
+ np.random.seed(seed)
135
+ random.seed(seed)
136
+ normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
137
+ std=[0.1673, 0.1695, 0.1705])
138
+ the_transforms = transforms.Compose([transforms.Resize((448, 448)),
139
+ transforms.RandomHorizontalFlip(),
140
+ transforms.ToTensor(),
141
+ normalize,
142
+ ])
143
+ dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
144
+ des_len=args.des_len,
145
+ obj_len=args.obj_len,
146
+ tgt_len=args.tgt_len,
147
+ img_root=args.img_root,
148
+ transform=the_transforms)
149
+ PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
150
+ input_resolution=448,
151
+ patch_size=16,
152
+ num_text_tokens=20000,
153
+ txt_seq_len=10000,
154
+ heads=4,
155
+ enc_depth=8,
156
+ dec_depth=8,
157
+ d_ff=1024,
158
+ dropout=0.1)
159
+ time_1 = time.time()
160
+ model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
161
+ model.load_state_dict(torch.load(model_path))
162
+ tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
163
+ time_2 = time.time()
164
+ print('Load model takes {}s'.format(time_2 - time_1))
165
+ out_list = []
166
+
167
+ label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img1_id, obj, device, tokenizer)
168
+ out_list.append([label_txt, output1, output2, output3])
169
+ label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img2_id, obj, device, tokenizer)
170
+ out_list.append([label_txt, output1, output2, output3])
171
+ return out_list
172
+
173
+ def validate_one_img(model, dataset_all, img_id, obj, device, tokenizer):
174
+ # print("start validate...")
175
+ start_time = time.time()
176
+ end1_time = time.time()
177
+ model.eval()
178
+ print(obj)
179
+ img_data, des, obj_data, target, img_id, obj_given = dataset_all.get_all_from_id(img_id, obj)
180
+ print(obj_given)
181
+ img_data = img_data.unsqueeze(0).to(device)
182
+ des = des.unsqueeze(0).to(device)
183
+ obj_given = obj_given.unsqueeze(0).to(device)
184
+ label = target.unsqueeze(0).to(device)
185
+
186
+ img_emed = model.ResNet(img_data)
187
+
188
+ img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
189
+ img_emed += model.img_pos_embed(img_emed)
190
+
191
+ des_embed = model.txt_embed(des)
192
+ des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
193
+ obj_embed = model.txt_embed(obj_given)
194
+ obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))
195
+ tgt_txt = torch.zeros(1, 1, dtype=torch.long, device=device) + 101
196
+ tgt_txt_embed = model.txt_embed(tgt_txt)
197
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)
198
+
199
+
200
+ # M_005
201
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
202
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
203
+ tgt_mask=None)
204
+ logits = model.to_logits(out)[:, -1]
205
+ sample = torch.argmax(logits, dim=-1)
206
+ value, index = logits.topk(3, dim=-1)
207
+ sample = index[0][0].unsqueeze(0)
208
+ sample_2nd = index[0][1].unsqueeze(0)
209
+ sample_3rd = index[0][2].unsqueeze(0)
210
+ tgt_txt_2nd = tgt_txt
211
+ tgt_txt_3rd = tgt_txt
212
+
213
+ cur_len = 1
214
+ while (cur_len < model.tgt_len and sample != 102): # 102 is the id of [SEP]
215
+ tgt_txt = torch.cat((tgt_txt, sample.unsqueeze(1)), dim=-1)
216
+ tgt_txt_embed = model.txt_embed(tgt_txt)
217
+ cur_len += 1
218
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
219
+ # out = model.transformer(prefix, tgt_txt_embed)
220
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
221
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
222
+ tgt_mask=None)
223
+ logits = model.to_logits(out)[:, -1]
224
+ sample = torch.argmax(logits, dim=-1)
225
+ label_txt = []
226
+ output_txt = []
227
+ obj_txt = []
228
+ for token in des[0].tolist():
229
+ if token > 103:
230
+ label_txt.append(token)
231
+ for token in tgt_txt[0].tolist():
232
+ if token > 103:
233
+ output_txt.append(token)
234
+ # for token in obj_data[0].tolist():
235
+ # if token > 103:
236
+ # obj_txt.append(token)
237
+ label_txt = tokenizer.convert_ids_to_tokens(label_txt)
238
+ label_txt = ''.join(label_txt)
239
+
240
+ # obj_txt = tokenizer.convert_ids_to_tokens(obj_txt)
241
+ output_txt = tokenizer.convert_ids_to_tokens(output_txt)
242
+ output1 = ''.join(output_txt)
243
+
244
+ # 2nd
245
+ cur_len = 1
246
+ while (cur_len < model.tgt_len and sample_2nd != 102): # 102 is the id of [SEP]
247
+ tgt_txt_2nd = torch.cat((tgt_txt_2nd, sample_2nd.unsqueeze(1)), dim=-1)
248
+ tgt_txt_embed = model.txt_embed(tgt_txt_2nd)
249
+ cur_len += 1
250
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
251
+ # out = model.transformer(prefix, tgt_txt_embed)
252
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
253
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
254
+ tgt_mask=None)
255
+ logits = model.to_logits(out)[:, -1]
256
+ # logits = logits[:, :-26]
257
+ # print(logits)
258
+ sample_2nd = torch.argmax(logits, dim=-1)
259
+
260
+ output_txt = []
261
+ for token in tgt_txt_2nd[0].tolist():
262
+ if token > 103:
263
+ output_txt.append(token)
264
+ output_txt = tokenizer.convert_ids_to_tokens(output_txt)
265
+ output2 = ''.join(output_txt)
266
+ # 3rd
267
+ cur_len = 1
268
+ while (cur_len < model.tgt_len and sample_3rd != 102): # 102 is the id of [SEP]
269
+ tgt_txt_3rd = torch.cat((tgt_txt_3rd, sample_3rd.unsqueeze(1)), dim=-1)
270
+ tgt_txt_embed = model.txt_embed(tgt_txt_3rd)
271
+ cur_len += 1
272
+ tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
273
+ # out = model.transformer(prefix, tgt_txt_embed)
274
+ out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
275
+ tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
276
+ tgt_mask=None)
277
+ logits = model.to_logits(out)[:, -1]
278
+ # logits = logits[:, :-26]
279
+ sample_3rd = torch.argmax(logits, dim=-1)
280
+
281
+ output_txt = []
282
+ for token in tgt_txt_3rd[0].tolist():
283
+ if token > 103:
284
+ output_txt.append(token)
285
+ output_txt = tokenizer.convert_ids_to_tokens(output_txt)
286
+ output3 = ''.join(output_txt)
287
+
288
+ total_time = time.time() - start_time
289
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
290
+ print(output1)
291
+ print(output2)
292
+ print(output3)
293
+ print('Validate time {}'.format(total_time_str))
294
+ return label_txt, output1, output2, output3
295
+
296
+
297
+ if __name__ == '__main__':
298
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
299
+ # parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
300
+ # args = parser.parse_args()
301
+ # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
302
+ # if args.output_dir:
303
+ # Path(args.output_dir).mkdir(parents=True, exist_ok=True)
304
+ # main(args)
305
+ model_name = '005'
306
+ model_path = r'E:\data\Download\models\attribute_desciption\outputs' + '/' + model_name + '/' + 'checkpoint0019.pth'
307
+ objs = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
308
+ "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]
309
+ for obj in objs:
310
+ print(obj)
311
+ out = validate('550695', '550567', obj, model_path)
312
+ sentences1 = out[0][0].replace(';', ',').split(',')
313
+ # gt = ""
314
+ #
315
+ # for i in sentences1:
316
+ # if obj in i:
317
+ # gt = i
318
+ # gt = " ".join(jieba.cut(gt))
319
+ # print(gt)
320
+ # for i in out[0]:
321
+ # i = " ".join(jieba.cut(i))
322
+ # print(i)
323
+ # print(gt)
324
+ # bleu = nltk.translate.bleu_score.sentence_bleu([i], gt)
325
+ # print(bleu)
326
+
327
+
328
+
329
+
330
+
331
+
332
+
Model/CLIP/cn_clip/__init__.py ADDED
File without changes
Model/CLIP/cn_clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (144 Bytes). View file
 
Model/CLIP/cn_clip/clip/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .bert_tokenizer import FullTokenizer
2
+
3
+ _tokenizer = FullTokenizer()
4
+ from .utils import load_from_name, available_models, tokenize, image_transform, load
5
+
Model/CLIP/cn_clip/clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (352 Bytes). View file
 
Model/CLIP/cn_clip/clip/__pycache__/bert_tokenizer.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
Model/CLIP/cn_clip/clip/__pycache__/utils.cpython-38.pyc ADDED
Binary file (5.99 kB). View file
 
Model/CLIP/cn_clip/clip/bert_tokenizer.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Tokenization classes."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import collections
23
+ import re
24
+ import unicodedata
25
+ import six
26
+ from functools import lru_cache
27
+ import os
28
+
29
+ @lru_cache()
30
+ def default_vocab():
31
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt")
32
+
33
+ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
34
+ """Checks whether the casing config is consistent with the checkpoint name."""
35
+
36
+ # The casing has to be passed in by the user and there is no explicit check
37
+ # as to whether it matches the checkpoint. The casing information probably
38
+ # should have been stored in the bert_config.json file, but it's not, so
39
+ # we have to heuristically detect it to validate.
40
+
41
+ if not init_checkpoint:
42
+ return
43
+
44
+ m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
45
+ if m is None:
46
+ return
47
+
48
+ model_name = m.group(1)
49
+
50
+ lower_models = [
51
+ "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
52
+ "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
53
+ ]
54
+
55
+ cased_models = [
56
+ "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
57
+ "multi_cased_L-12_H-768_A-12"
58
+ ]
59
+
60
+ is_bad_config = False
61
+ if model_name in lower_models and not do_lower_case:
62
+ is_bad_config = True
63
+ actual_flag = "False"
64
+ case_name = "lowercased"
65
+ opposite_flag = "True"
66
+
67
+ if model_name in cased_models and do_lower_case:
68
+ is_bad_config = True
69
+ actual_flag = "True"
70
+ case_name = "cased"
71
+ opposite_flag = "False"
72
+
73
+ if is_bad_config:
74
+ raise ValueError(
75
+ "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
76
+ "However, `%s` seems to be a %s model, so you "
77
+ "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
78
+ "how the model was pre-training. If this error is wrong, please "
79
+ "just comment out this check." % (actual_flag, init_checkpoint,
80
+ model_name, case_name, opposite_flag))
81
+
82
+
83
+ def convert_to_unicode(text):
84
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
85
+ if six.PY3:
86
+ if isinstance(text, str):
87
+ return text
88
+ elif isinstance(text, bytes):
89
+ return text.decode("utf-8", "ignore")
90
+ else:
91
+ raise ValueError("Unsupported string type: %s" % (type(text)))
92
+ elif six.PY2:
93
+ if isinstance(text, str):
94
+ return text.decode("utf-8", "ignore")
95
+ elif isinstance(text, unicode):
96
+ return text
97
+ else:
98
+ raise ValueError("Unsupported string type: %s" % (type(text)))
99
+ else:
100
+ raise ValueError("Not running on Python2 or Python 3?")
101
+
102
+
103
+ def printable_text(text):
104
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
105
+
106
+ # These functions want `str` for both Python2 and Python3, but in one case
107
+ # it's a Unicode string and in the other it's a byte string.
108
+ if six.PY3:
109
+ if isinstance(text, str):
110
+ return text
111
+ elif isinstance(text, bytes):
112
+ return text.decode("utf-8", "ignore")
113
+ else:
114
+ raise ValueError("Unsupported string type: %s" % (type(text)))
115
+ elif six.PY2:
116
+ if isinstance(text, str):
117
+ return text
118
+ elif isinstance(text, unicode):
119
+ return text.encode("utf-8")
120
+ else:
121
+ raise ValueError("Unsupported string type: %s" % (type(text)))
122
+ else:
123
+ raise ValueError("Not running on Python2 or Python 3?")
124
+
125
+
126
+ def load_vocab(vocab_file):
127
+ """Loads a vocabulary file into a dictionary."""
128
+ vocab = collections.OrderedDict()
129
+ index = 0
130
+ with open(vocab_file, "r", encoding='utf-8') as reader:
131
+ while True:
132
+ token = convert_to_unicode(reader.readline())
133
+ if not token:
134
+ break
135
+ token = token.strip()
136
+ vocab[token] = index
137
+ index += 1
138
+ return vocab
139
+
140
+
141
+ def convert_by_vocab(vocab, items):
142
+ """Converts a sequence of [tokens|ids] using the vocab."""
143
+ output = []
144
+ for item in items:
145
+ output.append(vocab[item])
146
+ return output
147
+
148
+
149
+ def convert_tokens_to_ids(vocab, tokens):
150
+ return convert_by_vocab(vocab, tokens)
151
+
152
+
153
+ def convert_ids_to_tokens(inv_vocab, ids):
154
+ return convert_by_vocab(inv_vocab, ids)
155
+
156
+
157
+ def whitespace_tokenize(text):
158
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
159
+ text = text.strip()
160
+ if not text:
161
+ return []
162
+ tokens = text.split()
163
+ return tokens
164
+
165
+
166
+ class FullTokenizer(object):
167
+ """Runs end-to-end tokenziation."""
168
+
169
+ def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
170
+ self.vocab = load_vocab(vocab_file)
171
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
172
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
173
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
174
+
175
+ def tokenize(self, text):
176
+ split_tokens = []
177
+ for token in self.basic_tokenizer.tokenize(text):
178
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
179
+ split_tokens.append(sub_token)
180
+
181
+ return split_tokens
182
+
183
+ def convert_tokens_to_ids(self, tokens):
184
+ return convert_by_vocab(self.vocab, tokens)
185
+
186
+ def convert_ids_to_tokens(self, ids):
187
+ return convert_by_vocab(self.inv_vocab, ids)
188
+
189
+ @staticmethod
190
+ def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
191
+ """ Converts a sequence of tokens (string) in a single string. """
192
+
193
+ def clean_up_tokenization(out_string):
194
+ """ Clean up a list of simple English tokenization artifacts
195
+ like spaces before punctuations and abreviated forms.
196
+ """
197
+ out_string = (
198
+ out_string.replace(" .", ".")
199
+ .replace(" ?", "?")
200
+ .replace(" !", "!")
201
+ .replace(" ,", ",")
202
+ .replace(" ' ", "'")
203
+ .replace(" n't", "n't")
204
+ .replace(" 'm", "'m")
205
+ .replace(" 's", "'s")
206
+ .replace(" 've", "'ve")
207
+ .replace(" 're", "'re")
208
+ )
209
+ return out_string
210
+
211
+ text = ' '.join(tokens).replace(' ##', '').strip()
212
+ if clean_up_tokenization_spaces:
213
+ clean_text = clean_up_tokenization(text)
214
+ return clean_text
215
+ else:
216
+ return text
217
+
218
+ def vocab_size(self):
219
+ return len(self.vocab)
220
+
221
+
222
+ class BasicTokenizer(object):
223
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
224
+
225
+ def __init__(self, do_lower_case=True):
226
+ """Constructs a BasicTokenizer.
227
+
228
+ Args:
229
+ do_lower_case: Whether to lower case the input.
230
+ """
231
+ self.do_lower_case = do_lower_case
232
+
233
+ def tokenize(self, text):
234
+ """Tokenizes a piece of text."""
235
+ text = convert_to_unicode(text)
236
+ text = self._clean_text(text)
237
+
238
+ # This was added on November 1st, 2018 for the multilingual and Chinese
239
+ # models. This is also applied to the English models now, but it doesn't
240
+ # matter since the English models were not trained on any Chinese data
241
+ # and generally don't have any Chinese data in them (there are Chinese
242
+ # characters in the vocabulary because Wikipedia does have some Chinese
243
+ # words in the English Wikipedia.).
244
+ text = self._tokenize_chinese_chars(text)
245
+
246
+ orig_tokens = whitespace_tokenize(text)
247
+ split_tokens = []
248
+ for token in orig_tokens:
249
+ if self.do_lower_case:
250
+ token = token.lower()
251
+ token = self._run_strip_accents(token)
252
+ split_tokens.extend(self._run_split_on_punc(token))
253
+
254
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
255
+ return output_tokens
256
+
257
+ def _run_strip_accents(self, text):
258
+ """Strips accents from a piece of text."""
259
+ text = unicodedata.normalize("NFD", text)
260
+ output = []
261
+ for char in text:
262
+ cat = unicodedata.category(char)
263
+ if cat == "Mn":
264
+ continue
265
+ output.append(char)
266
+ return "".join(output)
267
+
268
+ def _run_split_on_punc(self, text):
269
+ """Splits punctuation on a piece of text."""
270
+ chars = list(text)
271
+ i = 0
272
+ start_new_word = True
273
+ output = []
274
+ while i < len(chars):
275
+ char = chars[i]
276
+ if _is_punctuation(char):
277
+ output.append([char])
278
+ start_new_word = True
279
+ else:
280
+ if start_new_word:
281
+ output.append([])
282
+ start_new_word = False
283
+ output[-1].append(char)
284
+ i += 1
285
+
286
+ return ["".join(x) for x in output]
287
+
288
+ def _tokenize_chinese_chars(self, text):
289
+ """Adds whitespace around any CJK character."""
290
+ output = []
291
+ for char in text:
292
+ cp = ord(char)
293
+ if self._is_chinese_char(cp):
294
+ output.append(" ")
295
+ output.append(char)
296
+ output.append(" ")
297
+ else:
298
+ output.append(char)
299
+ return "".join(output)
300
+
301
+ def _is_chinese_char(self, cp):
302
+ """Checks whether CP is the codepoint of a CJK character."""
303
+ # This defines a "chinese character" as anything in the CJK Unicode block:
304
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
305
+ #
306
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
307
+ # despite its name. The modern Korean Hangul alphabet is a different block,
308
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
309
+ # space-separated words, so they are not treated specially and handled
310
+ # like the all of the other languages.
311
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
312
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
313
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
314
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
315
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
316
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
317
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
318
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
319
+ return True
320
+
321
+ return False
322
+
323
+ def _clean_text(self, text):
324
+ """Performs invalid character removal and whitespace cleanup on text."""
325
+ output = []
326
+ for char in text:
327
+ cp = ord(char)
328
+ if cp == 0 or cp == 0xfffd or _is_control(char):
329
+ continue
330
+ if _is_whitespace(char):
331
+ output.append(" ")
332
+ else:
333
+ output.append(char)
334
+ return "".join(output)
335
+
336
+
337
+ class WordpieceTokenizer(object):
338
+ """Runs WordPiece tokenziation."""
339
+
340
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
341
+ self.vocab = vocab
342
+ self.unk_token = unk_token
343
+ self.max_input_chars_per_word = max_input_chars_per_word
344
+
345
+ def tokenize(self, text):
346
+ """Tokenizes a piece of text into its word pieces.
347
+
348
+ This uses a greedy longest-match-first algorithm to perform tokenization
349
+ using the given vocabulary.
350
+
351
+ For example:
352
+ input = "unaffable"
353
+ output = ["un", "##aff", "##able"]
354
+
355
+ Args:
356
+ text: A single token or whitespace separated tokens. This should have
357
+ already been passed through `BasicTokenizer.
358
+
359
+ Returns:
360
+ A list of wordpiece tokens.
361
+ """
362
+
363
+ text = convert_to_unicode(text)
364
+
365
+ output_tokens = []
366
+ for token in whitespace_tokenize(text):
367
+ chars = list(token)
368
+ if len(chars) > self.max_input_chars_per_word:
369
+ output_tokens.append(self.unk_token)
370
+ continue
371
+
372
+ is_bad = False
373
+ start = 0
374
+ sub_tokens = []
375
+ while start < len(chars):
376
+ end = len(chars)
377
+ cur_substr = None
378
+ while start < end:
379
+ substr = "".join(chars[start:end])
380
+ if start > 0:
381
+ substr = "##" + substr
382
+ if substr in self.vocab:
383
+ cur_substr = substr
384
+ break
385
+ end -= 1
386
+ if cur_substr is None:
387
+ is_bad = True
388
+ break
389
+ sub_tokens.append(cur_substr)
390
+ start = end
391
+
392
+ if is_bad:
393
+ output_tokens.append(self.unk_token)
394
+ else:
395
+ output_tokens.extend(sub_tokens)
396
+ return output_tokens
397
+
398
+
399
+ def _is_whitespace(char):
400
+ """Checks whether `chars` is a whitespace character."""
401
+ # \t, \n, and \r are technically contorl characters but we treat them
402
+ # as whitespace since they are generally considered as such.
403
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
404
+ return True
405
+ cat = unicodedata.category(char)
406
+ if cat == "Zs":
407
+ return True
408
+ return False
409
+
410
+
411
+ def _is_control(char):
412
+ """Checks whether `chars` is a control character."""
413
+ # These are technically control characters but we count them as whitespace
414
+ # characters.
415
+ if char == "\t" or char == "\n" or char == "\r":
416
+ return False
417
+ cat = unicodedata.category(char)
418
+ if cat in ("Cc", "Cf"):
419
+ return True
420
+ return False
421
+
422
+
423
+ def _is_punctuation(char):
424
+ """Checks whether `chars` is a punctuation character."""
425
+ cp = ord(char)
426
+ # We treat all non-letter/number ASCII as punctuation.
427
+ # Characters such as "^", "$", and "`" are not in the Unicode
428
+ # Punctuation class but we treat them as punctuation anyways, for
429
+ # consistency.
430
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
431
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
432
+ return True
433
+ cat = unicodedata.category(char)
434
+ if cat.startswith("P"):
435
+ return True
436
+ return False
Model/CLIP/cn_clip/clip/configuration_bert.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class BertConfig(object):
26
+ r"""
27
+ :class:`~transformers.BertConfig` is the configuration class to store the configuration of a
28
+ `BertModel`.
29
+
30
+
31
+ Arguments:
32
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
33
+ hidden_size: Size of the encoder layers and the pooler layer.
34
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
35
+ num_attention_heads: Number of attention heads for each attention layer in
36
+ the Transformer encoder.
37
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
38
+ layer in the Transformer encoder.
39
+ hidden_act: The non-linear activation function (function or string) in the
40
+ encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
41
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
42
+ layers in the embeddings, encoder, and pooler.
43
+ attention_probs_dropout_prob: The dropout ratio for the attention
44
+ probabilities.
45
+ max_position_embeddings: The maximum sequence length that this model might
46
+ ever be used with. Typically set this to something large just in case
47
+ (e.g., 512 or 1024 or 2048).
48
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
49
+ `BertModel`.
50
+ initializer_range: The sttdev of the truncated_normal_initializer for
51
+ initializing all weight matrices.
52
+ layer_norm_eps: The epsilon used by LayerNorm.
53
+ """
54
+
55
+ def __init__(self,
56
+ vocab_size_or_config_json_file=30522,
57
+ hidden_size=768,
58
+ num_hidden_layers=12,
59
+ num_attention_heads=12,
60
+ intermediate_size=3072,
61
+ hidden_act="gelu",
62
+ hidden_dropout_prob=0.1,
63
+ attention_probs_dropout_prob=0.1,
64
+ max_position_embeddings=512,
65
+ type_vocab_size=2,
66
+ initializer_range=0.02,
67
+ layer_norm_eps=1e-12,
68
+ output_attentions=False,
69
+ output_hidden_states=False
70
+ ):
71
+ self.vocab_size = vocab_size_or_config_json_file
72
+ self.hidden_size = hidden_size
73
+ self.num_hidden_layers = num_hidden_layers
74
+ self.num_attention_heads = num_attention_heads
75
+ self.hidden_act = hidden_act
76
+ self.intermediate_size = intermediate_size
77
+ self.hidden_dropout_prob = hidden_dropout_prob
78
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
79
+ self.max_position_embeddings = max_position_embeddings
80
+ self.type_vocab_size = type_vocab_size
81
+ self.initializer_range = initializer_range
82
+ self.layer_norm_eps = layer_norm_eps
83
+ self.output_attentions = output_attentions
84
+ self.output_hidden_states = output_hidden_states
Model/CLIP/cn_clip/clip/model.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ from itertools import repeat
4
+ import collections.abc
5
+
6
+ import math
7
+ import logging
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ from cn_clip.clip import _tokenizer
15
+ from cn_clip.clip.configuration_bert import BertConfig
16
+ from cn_clip.clip.modeling_bert import BertModel
17
+
18
+
19
+ class Bottleneck(nn.Module):
20
+ expansion = 4
21
+
22
+ def __init__(self, inplanes, planes, stride=1):
23
+ super().__init__()
24
+
25
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
26
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
27
+ self.bn1 = nn.BatchNorm2d(planes)
28
+
29
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
30
+ self.bn2 = nn.BatchNorm2d(planes)
31
+
32
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
33
+
34
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
35
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
36
+
37
+ self.relu = nn.ReLU(inplace=True)
38
+ self.downsample = None
39
+ self.stride = stride
40
+
41
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
42
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
43
+ self.downsample = nn.Sequential(OrderedDict([
44
+ ("-1", nn.AvgPool2d(stride)),
45
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
46
+ ("1", nn.BatchNorm2d(planes * self.expansion))
47
+ ]))
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ identity = x
51
+
52
+ out = self.relu(self.bn1(self.conv1(x)))
53
+ out = self.relu(self.bn2(self.conv2(out)))
54
+ out = self.avgpool(out)
55
+ out = self.bn3(self.conv3(out))
56
+
57
+ if self.downsample is not None:
58
+ identity = self.downsample(x)
59
+
60
+ out += identity
61
+ out = self.relu(out)
62
+ return out
63
+
64
+
65
+ class AttentionPool2d(nn.Module):
66
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
67
+ super().__init__()
68
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
69
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
70
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
71
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
72
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
73
+ self.num_heads = num_heads
74
+
75
+ def forward(self, x):
76
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
77
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
78
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
79
+ x, _ = F.multi_head_attention_forward(
80
+ query=x, key=x, value=x,
81
+ embed_dim_to_check=x.shape[-1],
82
+ num_heads=self.num_heads,
83
+ q_proj_weight=self.q_proj.weight,
84
+ k_proj_weight=self.k_proj.weight,
85
+ v_proj_weight=self.v_proj.weight,
86
+ in_proj_weight=None,
87
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
88
+ bias_k=None,
89
+ bias_v=None,
90
+ add_zero_attn=False,
91
+ dropout_p=0,
92
+ out_proj_weight=self.c_proj.weight,
93
+ out_proj_bias=self.c_proj.bias,
94
+ use_separate_proj_weight=True,
95
+ training=self.training,
96
+ need_weights=False
97
+ )
98
+
99
+ return x[0]
100
+
101
+
102
+ class ModifiedResNet(nn.Module):
103
+ """
104
+ A ResNet class that is similar to torchvision's but contains the following changes:
105
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
106
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
107
+ - The final pooling layer is a QKV attention instead of an average pool
108
+ """
109
+
110
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
111
+ super().__init__()
112
+ self.output_dim = output_dim
113
+ self.input_resolution = input_resolution
114
+
115
+ # the 3-layer stem
116
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
117
+ self.bn1 = nn.BatchNorm2d(width // 2)
118
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
119
+ self.bn2 = nn.BatchNorm2d(width // 2)
120
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
121
+ self.bn3 = nn.BatchNorm2d(width)
122
+ self.avgpool = nn.AvgPool2d(2)
123
+ self.relu = nn.ReLU(inplace=True)
124
+
125
+ # residual layers
126
+ self._inplanes = width # this is a *mutable* variable used during construction
127
+ self.layer1 = self._make_layer(width, layers[0])
128
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
129
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
130
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
131
+
132
+ embed_dim = width * 32 # the ResNet feature dimension
133
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
134
+
135
+ def _make_layer(self, planes, blocks, stride=1):
136
+ layers = [Bottleneck(self._inplanes, planes, stride)]
137
+
138
+ self._inplanes = planes * Bottleneck.expansion
139
+ for _ in range(1, blocks):
140
+ layers.append(Bottleneck(self._inplanes, planes))
141
+
142
+ return nn.Sequential(*layers)
143
+
144
+ @torch.jit.ignore
145
+ def set_grad_checkpointing(self, enable=True):
146
+ # FIXME support for non-transformer
147
+ pass
148
+
149
+ def forward(self, x):
150
+ def stem(x):
151
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
152
+ x = self.relu(bn(conv(x)))
153
+ x = self.avgpool(x)
154
+ return x
155
+
156
+ x = x.type(self.conv1.weight.dtype)
157
+ x = stem(x)
158
+ x = self.layer1(x)
159
+ x = self.layer2(x)
160
+ x = self.layer3(x)
161
+ x = self.layer4(x)
162
+ x = self.attnpool(x)
163
+
164
+ return x
165
+
166
+
167
+ class LayerNorm(nn.LayerNorm):
168
+ """Subclass torch's LayerNorm to handle fp16."""
169
+
170
+ def forward(self, x: torch.Tensor):
171
+ orig_type = x.dtype
172
+ ret = super().forward(x.type(torch.float32))
173
+ return ret.type(orig_type)
174
+
175
+
176
+ class QuickGELU(nn.Module):
177
+ def forward(self, x: torch.Tensor):
178
+ return x * torch.sigmoid(1.702 * x)
179
+
180
+
181
+ class ResidualAttentionBlock(nn.Module):
182
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
183
+ super().__init__()
184
+
185
+ self.attn = nn.MultiheadAttention(d_model, n_head)
186
+ self.ln_1 = LayerNorm(d_model)
187
+ self.mlp = nn.Sequential(OrderedDict([
188
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
189
+ ("gelu", QuickGELU()),
190
+ ("c_proj", nn.Linear(d_model * 4, d_model))
191
+ ]))
192
+ self.ln_2 = LayerNorm(d_model)
193
+ self.attn_mask = attn_mask
194
+
195
+ def attention(self, x: torch.Tensor):
196
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
197
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
198
+
199
+ def forward(self, x: torch.Tensor):
200
+ x = x + self.attention(self.ln_1(x))
201
+ x = x + self.mlp(self.ln_2(x))
202
+ return x
203
+
204
+
205
+ class Transformer(nn.Module):
206
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
207
+ super().__init__()
208
+ self.width = width
209
+ self.layers = layers
210
+ self.grad_checkpointing = False
211
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ if self.grad_checkpointing and not torch.jit.is_scripting():
215
+ for r in self.resblocks:
216
+ x = checkpoint(r, x)
217
+ return x
218
+ return self.resblocks(x)
219
+
220
+
221
+ class VisualTransformer(nn.Module):
222
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
223
+ super().__init__()
224
+ self.input_resolution = input_resolution
225
+ self.grid_size = (self.input_resolution // patch_size, self.input_resolution // patch_size)
226
+ self.output_dim = output_dim
227
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
228
+
229
+ scale = width ** -0.5
230
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
231
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
232
+ self.ln_pre = LayerNorm(width)
233
+
234
+ self.transformer = Transformer(width, layers, heads)
235
+
236
+ self.ln_post = LayerNorm(width)
237
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
238
+
239
+ @torch.jit.ignore
240
+ def set_grad_checkpointing(self, enable=True):
241
+ self.transformer.grad_checkpointing = enable
242
+
243
+ def forward(self, x: torch.Tensor):
244
+ x = self.conv1(x) # shape = [*, width, grid, grid]
245
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
246
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
247
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
248
+ x = x + self.positional_embedding.to(x.dtype)
249
+ x = self.ln_pre(x)
250
+
251
+ x = x.permute(1, 0, 2) # NLD -> LND
252
+ x = self.transformer(x)
253
+ x = x.permute(1, 0, 2) # LND -> NLD
254
+
255
+ x = self.ln_post(x[:, 0, :])
256
+
257
+ if self.proj is not None:
258
+ x = x @ self.proj
259
+
260
+ return x
261
+
262
+
263
+ class CLIP(nn.Module):
264
+ def __init__(self,
265
+ embed_dim: int,
266
+ # vision
267
+ image_resolution: int,
268
+ vision_layers: Union[Tuple[int, int, int, int], int],
269
+ vision_width: int,
270
+ vision_patch_size: int,
271
+ # text
272
+ vocab_size: int,
273
+ text_attention_probs_dropout_prob: float,
274
+ text_hidden_act: str,
275
+ text_hidden_dropout_prob: float,
276
+ text_hidden_size: int,
277
+ text_initializer_range: float,
278
+ text_intermediate_size: int,
279
+ text_max_position_embeddings: int,
280
+ text_num_attention_heads: int,
281
+ text_num_hidden_layers: int,
282
+ text_type_vocab_size: int,
283
+ tokenizer = _tokenizer,
284
+ # vision head width, added this param for ViT-H
285
+ vision_head_width: int = 64,
286
+ ):
287
+ super().__init__()
288
+
289
+ if isinstance(vision_layers, (tuple, list)):
290
+ vision_heads = vision_width * 32 // vision_head_width
291
+ self.visual = ModifiedResNet(
292
+ layers=vision_layers,
293
+ output_dim=embed_dim,
294
+ heads=vision_heads,
295
+ input_resolution=image_resolution,
296
+ width=vision_width
297
+ )
298
+ else:
299
+ vision_heads = vision_width // vision_head_width
300
+ self.visual = VisualTransformer(
301
+ input_resolution=image_resolution,
302
+ patch_size=vision_patch_size,
303
+ width=vision_width,
304
+ layers=vision_layers,
305
+ heads=vision_heads,
306
+ output_dim=embed_dim
307
+ )
308
+
309
+ self.bert_config = BertConfig(
310
+ vocab_size_or_config_json_file=vocab_size,
311
+ hidden_size=text_hidden_size,
312
+ num_hidden_layers=text_num_hidden_layers,
313
+ num_attention_heads=text_num_attention_heads,
314
+ intermediate_size=text_intermediate_size,
315
+ hidden_act=text_hidden_act,
316
+ hidden_dropout_prob=text_hidden_dropout_prob,
317
+ attention_probs_dropout_prob=text_attention_probs_dropout_prob,
318
+ max_position_embeddings=text_max_position_embeddings,
319
+ type_vocab_size=text_type_vocab_size,
320
+ initializer_range=text_initializer_range,
321
+ layer_norm_eps=1e-12,
322
+ )
323
+ self.bert = BertModel(self.bert_config)
324
+
325
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
326
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
327
+
328
+ self.tokenizer = tokenizer
329
+
330
+ self.initialize_parameters()
331
+
332
+ def initialize_parameters(self):
333
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
334
+
335
+ if isinstance(self.visual, ModifiedResNet):
336
+ if self.visual.attnpool is not None:
337
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
338
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
339
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
340
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
341
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
342
+
343
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
344
+ for name, param in resnet_block.named_parameters():
345
+ if name.endswith("bn3.weight"):
346
+ nn.init.zeros_(param)
347
+
348
+ if self.text_projection is not None:
349
+ nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
350
+
351
+ @torch.jit.ignore
352
+ def set_grad_checkpointing(self, enable=True):
353
+ self.visual.set_grad_checkpointing(enable)
354
+ self.bert.set_grad_checkpointing(enable)
355
+
356
+ @property
357
+ def dtype(self):
358
+ return self.visual.conv1.weight.dtype
359
+
360
+ def encode_image(self, image):
361
+ return self.visual(image.type(self.dtype))
362
+
363
+ def encode_text(self, text):
364
+ pad_index = self.tokenizer.vocab['[PAD]']
365
+ attn_mask = text.ne(pad_index).type(self.dtype)
366
+ x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
367
+ return x[:, 0, :] @ self.text_projection
368
+
369
+ def forward(self, image, text):
370
+ assert image is not None or text is not None, "text and image cannot both be None!"
371
+
372
+ if image is None:
373
+ return self.encode_text(text)
374
+ elif text is None:
375
+ return self.encode_image(image)
376
+
377
+ image_features = self.encode_image(image)
378
+ text_features = self.encode_text(text)
379
+
380
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
381
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
382
+
383
+ return image_features, text_features, self.logit_scale.exp()
384
+
385
+ def get_similarity(self, image, text):
386
+ image_features = self.encode_image(image)
387
+ text_features = self.encode_text(text)
388
+
389
+ # normalized features
390
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
391
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
392
+
393
+ # cosine similarity as logits
394
+ logit_scale = self.logit_scale.exp()
395
+ logits_per_image = logit_scale * image_features @ text_features.t()
396
+ logits_per_text = logits_per_image.t()
397
+
398
+ # shape = [global_batch_size, global_batch_size]
399
+ return logits_per_image, logits_per_text
400
+
401
+
402
+ def convert_models_to_fp32(model):
403
+ for p in model.parameters():
404
+ p.data = p.data.float()
405
+ if p.grad is not None:
406
+ p.grad.data = p.grad.data.float()
407
+
408
+
409
+ def convert_weights(model: nn.Module):
410
+ """Convert applicable model parameters to fp16"""
411
+
412
+ def _convert_weights_to_fp16(l):
413
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
414
+ l.weight.data = l.weight.data.half()
415
+ if l.bias is not None:
416
+ l.bias.data = l.bias.data.half()
417
+
418
+ if isinstance(l, nn.MultiheadAttention):
419
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
420
+ tensor = getattr(l, attr)
421
+ if tensor is not None:
422
+ tensor.data = tensor.data.half()
423
+
424
+ if isinstance(l, BertModel):
425
+ l.to(torch.half)
426
+
427
+ for name in ["text_projection", "proj"]:
428
+ if hasattr(l, name):
429
+ attr = getattr(l, name)
430
+ if attr is not None:
431
+ attr.data = attr.data.half()
432
+
433
+ model.apply(_convert_weights_to_fp16)
434
+
435
+
436
+ def restore_model(model, clip_state_dict: dict, bert_state_dict: dict):
437
+ merged_state_dict = {}
438
+
439
+ # use clip_state_dict to initialize the image encoder & logit scale
440
+ if clip_state_dict is not None:
441
+ for k, v in clip_state_dict.items():
442
+ if k.startswith("visual") or k == "logit_scale":
443
+ merged_state_dict[k] = v
444
+
445
+ # use bert_state_dict to initialize the text encoder
446
+ if bert_state_dict is not None:
447
+ for k, v in bert_state_dict.items():
448
+ if k.startswith("bert") and "bert.pooler" not in k:
449
+ merged_state_dict[k] = v
450
+
451
+ convert_weights(model)
452
+ resize_pos_embed(merged_state_dict, model)
453
+ model.load_state_dict(merged_state_dict, strict=False)
454
+ return model.eval()
455
+
456
+
457
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1, prefix=""):
458
+ # Rescale the grid of position embeddings when loading from state_dict
459
+ old_pos_embed = state_dict.get(prefix + 'visual.positional_embedding', None)
460
+ model = model.module if hasattr(model, 'module') else model
461
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
462
+ return
463
+ grid_size = to_2tuple(model.visual.grid_size)
464
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
465
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
466
+ if new_seq_len == old_pos_embed.shape[0]:
467
+ return
468
+
469
+ if extra_tokens:
470
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
471
+ else:
472
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
473
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
474
+
475
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
476
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
477
+ pos_emb_img = F.interpolate(
478
+ pos_emb_img,
479
+ size=grid_size,
480
+ mode=interpolation,
481
+ align_corners=True,
482
+ )
483
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
484
+ if pos_emb_tok is not None:
485
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
486
+ else:
487
+ new_pos_embed = pos_emb_img
488
+ state_dict[prefix + 'visual.positional_embedding'] = new_pos_embed
489
+
490
+
491
+ # From PyTorch internals
492
+ def _ntuple(n):
493
+ def parse(x):
494
+ if isinstance(x, collections.abc.Iterable):
495
+ return x
496
+ return tuple(repeat(x, n))
497
+ return parse
498
+
499
+
500
+ to_1tuple = _ntuple(1)
501
+ to_2tuple = _ntuple(2)
502
+ to_3tuple = _ntuple(3)
503
+ to_4tuple = _ntuple(4)
504
+ to_ntuple = lambda n, x: _ntuple(n)(x)
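
The similarity path above (encode, L2-normalize, scale by logit_scale) is the part of this class most callers touch. A minimal sketch of driving it, assuming an already constructed CLIP instance plus preprocessed image pixels and tokenized text; the tensor shapes and the helper name rank_texts are illustrative, not taken from this repo's scripts:

import torch

@torch.no_grad()
def rank_texts(model, image, text_tokens):
    # image: [B, 3, H, W] preprocessed pixels; text_tokens: [N, context_length] token ids
    image_features, text_features, logit_scale = model(image, text_tokens)
    # forward() already L2-normalizes both feature sets, so the matrix product below is a
    # cosine-similarity matrix; logit_scale turns it into logits
    logits_per_image = logit_scale * image_features @ text_features.t()
    return logits_per_image.softmax(dim=-1)  # [B, N] probabilities over the candidate texts
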
Model/CLIP/cn_clip/clip/model_configs/RBT3-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 768,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 3072,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 12,
11
+ "text_num_hidden_layers": 3,
12
+ "text_type_vocab_size": 2
13
+ }
Model/CLIP/cn_clip/clip/model_configs/RN50.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "image_resolution": 224,
4
+ "vision_layers": "[3,4,6,3]",
5
+ "vision_width": 64,
6
+ "vision_patch_size": null
7
+ }
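
Note that `vision_layers` for the ResNet config is stored as the string "[3,4,6,3]" rather than a JSON array; `create_model()` in utils.py (later in this diff) converts it back with `eval()`. A small sketch of an equivalent, slightly safer parse; the helper name `load_vision_config` is hypothetical:

import ast
import json

def load_vision_config(path):
    # Load a model_configs/*.json file and decode string-encoded fields.
    with open(path, "r") as f:
        cfg = json.load(f)
    # "[3,4,6,3]" -> [3, 4, 6, 3]; ast.literal_eval avoids executing arbitrary code
    if isinstance(cfg.get("vision_layers"), str):
        cfg["vision_layers"] = ast.literal_eval(cfg["vision_layers"])
    return cfg
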
Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 768,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 3072,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 12,
11
+ "text_num_hidden_layers": 12,
12
+ "text_type_vocab_size": 2
13
+ }
Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-large-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 1024,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 4096,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 16,
11
+ "text_num_hidden_layers": 24,
12
+ "text_type_vocab_size": 2
13
+ }
Model/CLIP/cn_clip/clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_resolution": 224,
4
+ "vision_layers": 12,
5
+ "vision_width": 768,
6
+ "vision_patch_size": 16
7
+ }
Model/CLIP/cn_clip/clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_resolution": 224,
4
+ "vision_layers": 12,
5
+ "vision_width": 768,
6
+ "vision_patch_size": 32
7
+ }
Model/CLIP/cn_clip/clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "image_resolution": 224,
4
+ "vision_layers": 32,
5
+ "vision_width": 1280,
6
+ "vision_head_width": 80,
7
+ "vision_patch_size": 14
8
+ }
Model/CLIP/cn_clip/clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "image_resolution": 336,
4
+ "vision_layers": 24,
5
+ "vision_width": 1024,
6
+ "vision_patch_size": 14
7
+ }
Model/CLIP/cn_clip/clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "image_resolution": 224,
4
+ "vision_layers": 24,
5
+ "vision_width": 1024,
6
+ "vision_patch_size": 14
7
+ }
Model/CLIP/cn_clip/clip/model_configs/for_learn.py ADDED
@@ -0,0 +1,16 @@
1
+ import json
2
+ from pathlib import Path
3
+ import os
4
+
5
+ vision_model = "ViT-B-16"
6
+ vision_model_config_file = \
7
+ Path(__file__).parent / f"{vision_model.replace('/', '-')}.json"
8
+ print('Loading vision model config from', vision_model_config_file)
9
+ assert os.path.exists(vision_model_config_file)
10
+ with open(vision_model_config_file, 'r') as fv:
11
+ model_info = json.load(fv)
12
+ print('Model info:', model_info)
13
+ if isinstance(model_info['vision_layers'], str):
14
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
15
+
16
+
Model/CLIP/cn_clip/clip/modeling_bert.py ADDED
@@ -0,0 +1,460 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ from io import open
26
+
27
+ import torch
28
+ from torch import nn
29
+ from torch.utils.checkpoint import checkpoint
30
+
31
+ from .configuration_bert import BertConfig
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ def gelu(x):
36
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
37
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
38
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
39
+ Also see https://arxiv.org/abs/1606.08415
40
+ """
41
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
42
+
43
+ def gelu_new(x):
44
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
45
+ Also see https://arxiv.org/abs/1606.08415
46
+ """
47
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
48
+
49
+ def swish(x):
50
+ return x * torch.sigmoid(x)
51
+
52
+
53
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
54
+
55
+
56
+ BertLayerNorm = torch.nn.LayerNorm
57
+
58
+ class BertEmbeddings(nn.Module):
59
+ """Construct the embeddings from word, position and token_type embeddings.
60
+ """
61
+ def __init__(self, config):
62
+ super(BertEmbeddings, self).__init__()
63
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
64
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
65
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
66
+
67
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
68
+ # any TensorFlow checkpoint file
69
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
70
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
71
+
72
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
73
+ seq_length = input_ids.size(1)
74
+ if position_ids is None:
75
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
76
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
77
+ if token_type_ids is None:
78
+ token_type_ids = torch.zeros_like(input_ids)
79
+
80
+ words_embeddings = self.word_embeddings(input_ids)
81
+ position_embeddings = self.position_embeddings(position_ids)
82
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
83
+
84
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
85
+ embeddings = self.LayerNorm(embeddings)
86
+ embeddings = self.dropout(embeddings)
87
+ return embeddings
88
+
89
+
90
+ class BertSelfAttention(nn.Module):
91
+ def __init__(self, config):
92
+ super(BertSelfAttention, self).__init__()
93
+ if config.hidden_size % config.num_attention_heads != 0:
94
+ raise ValueError(
95
+ "The hidden size (%d) is not a multiple of the number of attention "
96
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
97
+ self.output_attentions = config.output_attentions
98
+
99
+ self.num_attention_heads = config.num_attention_heads
100
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
102
+
103
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
105
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
106
+
107
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
108
+
109
+ def transpose_for_scores(self, x):
110
+
111
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
112
+ x = x.view(*new_x_shape)
113
+ return x.permute(0, 2, 1, 3)
114
+
115
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
116
+ mixed_query_layer = self.query(hidden_states)
117
+ mixed_key_layer = self.key(hidden_states)
118
+ mixed_value_layer = self.value(hidden_states)
119
+
120
+ query_layer = self.transpose_for_scores(mixed_query_layer)
121
+ key_layer = self.transpose_for_scores(mixed_key_layer)
122
+ value_layer = self.transpose_for_scores(mixed_value_layer)
123
+
124
+ # Take the dot product between "query" and "key" to get the raw attention scores.
125
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
126
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
127
+ if attention_mask is not None:
128
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
129
+ attention_scores = attention_scores + attention_mask
130
+
131
+ # Normalize the attention scores to probabilities.
132
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
133
+
134
+ # This is actually dropping out entire tokens to attend to, which might
135
+ # seem a bit unusual, but is taken from the original Transformer paper.
136
+ attention_probs = self.dropout(attention_probs)
137
+
138
+ # Mask heads if we want to
139
+ if head_mask is not None:
140
+ attention_probs = attention_probs * head_mask
141
+
142
+ context_layer = torch.matmul(attention_probs, value_layer)
143
+
144
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
145
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
146
+ context_layer = context_layer.view(*new_context_layer_shape)
147
+
148
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
149
+ return outputs
150
+
151
+
152
+ class BertSelfOutput(nn.Module):
153
+ def __init__(self, config):
154
+ super(BertSelfOutput, self).__init__()
155
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
156
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
157
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
158
+
159
+ def forward(self, hidden_states, input_tensor):
160
+ hidden_states = self.dense(hidden_states)
161
+ hidden_states = self.dropout(hidden_states)
162
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
163
+ return hidden_states
164
+
165
+
166
+ class BertAttention(nn.Module):
167
+ def __init__(self, config):
168
+ super(BertAttention, self).__init__()
169
+ self.self = BertSelfAttention(config)
170
+ self.output = BertSelfOutput(config)
171
+ self.pruned_heads = set()
172
+
173
+ def forward(self, input_tensor, attention_mask=None, head_mask=None):
174
+ self_outputs = self.self(input_tensor, attention_mask, head_mask)
175
+ attention_output = self.output(self_outputs[0], input_tensor)
176
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
177
+ return outputs
178
+
179
+
180
+ class BertIntermediate(nn.Module):
181
+ def __init__(self, config):
182
+ super(BertIntermediate, self).__init__()
183
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
184
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
185
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
186
+ else:
187
+ self.intermediate_act_fn = config.hidden_act
188
+
189
+ def forward(self, hidden_states):
190
+ hidden_states = self.dense(hidden_states)
191
+ hidden_states = self.intermediate_act_fn(hidden_states)
192
+ return hidden_states
193
+
194
+
195
+ class BertOutput(nn.Module):
196
+ def __init__(self, config):
197
+ super(BertOutput, self).__init__()
198
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
199
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
200
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
201
+
202
+ def forward(self, hidden_states, input_tensor):
203
+ hidden_states = self.dense(hidden_states)
204
+ hidden_states = self.dropout(hidden_states)
205
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
206
+ return hidden_states
207
+
208
+
209
+ class BertLayer(nn.Module):
210
+ def __init__(self, config):
211
+ super(BertLayer, self).__init__()
212
+ self.attention = BertAttention(config)
213
+ self.intermediate = BertIntermediate(config)
214
+ self.output = BertOutput(config)
215
+
216
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
217
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
218
+ attention_output = attention_outputs[0]
219
+ intermediate_output = self.intermediate(attention_output)
220
+ layer_output = self.output(intermediate_output, attention_output)
221
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
222
+ if len(outputs) == 1:
223
+ return outputs[0]
224
+ return outputs
225
+
226
+
227
+ class BertEncoder(nn.Module):
228
+ def __init__(self, config):
229
+ super(BertEncoder, self).__init__()
230
+ self.output_attentions = config.output_attentions
231
+ self.output_hidden_states = config.output_hidden_states
232
+ self.grad_checkpointing = False
233
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
234
+
235
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
236
+ all_hidden_states = ()
237
+ all_attentions = ()
238
+ for i, layer_module in enumerate(self.layer):
239
+ if self.output_hidden_states:
240
+ all_hidden_states = all_hidden_states + (hidden_states,)
241
+
242
+ if self.grad_checkpointing and not torch.jit.is_scripting():
243
+ layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i])
244
+ else:
245
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
246
+ if not isinstance(layer_outputs, tuple):
247
+ layer_outputs = (layer_outputs, )
248
+ hidden_states = layer_outputs[0]
249
+
250
+ if self.output_attentions:
251
+ all_attentions = all_attentions + (layer_outputs[1],)
252
+
253
+ # Add last layer
254
+ if self.output_hidden_states:
255
+ all_hidden_states = all_hidden_states + (hidden_states,)
256
+
257
+ outputs = (hidden_states,)
258
+ if self.output_hidden_states:
259
+ outputs = outputs + (all_hidden_states,)
260
+ if self.output_attentions:
261
+ outputs = outputs + (all_attentions,)
262
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
263
+
264
+
265
+ class BertPooler(nn.Module):
266
+ def __init__(self, config):
267
+ super(BertPooler, self).__init__()
268
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
269
+ self.activation = nn.Tanh()
270
+
271
+ def forward(self, hidden_states):
272
+ # We "pool" the model by simply taking the hidden state corresponding
273
+ # to the first token.
274
+ first_token_tensor = hidden_states[:, 0]
275
+ pooled_output = self.dense(first_token_tensor)
276
+ pooled_output = self.activation(pooled_output)
277
+ return pooled_output
278
+
279
+
280
+ class BertPredictionHeadTransform(nn.Module):
281
+ def __init__(self, config):
282
+ super(BertPredictionHeadTransform, self).__init__()
283
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
284
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
285
+ self.transform_act_fn = ACT2FN[config.hidden_act]
286
+ else:
287
+ self.transform_act_fn = config.hidden_act
288
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
289
+
290
+ def forward(self, hidden_states):
291
+ hidden_states = self.dense(hidden_states)
292
+ hidden_states = self.transform_act_fn(hidden_states)
293
+ hidden_states = self.LayerNorm(hidden_states)
294
+ return hidden_states
295
+
296
+
297
+ class BertLMPredictionHead(nn.Module):
298
+ def __init__(self, config):
299
+ super(BertLMPredictionHead, self).__init__()
300
+ self.transform = BertPredictionHeadTransform(config)
301
+
302
+ # The output weights are the same as the input embeddings, but there is
303
+ # an output-only bias for each token.
304
+ self.decoder = nn.Linear(config.hidden_size,
305
+ config.vocab_size,
306
+ bias=False)
307
+
308
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
309
+
310
+ def forward(self, hidden_states):
311
+ hidden_states = self.transform(hidden_states)
312
+ hidden_states = self.decoder(hidden_states) + self.bias
313
+ return hidden_states
314
+
315
+
316
+ class BertOnlyMLMHead(nn.Module):
317
+ def __init__(self, config):
318
+ super(BertOnlyMLMHead, self).__init__()
319
+ self.predictions = BertLMPredictionHead(config)
320
+
321
+ def forward(self, sequence_output):
322
+ prediction_scores = self.predictions(sequence_output)
323
+ return prediction_scores
324
+
325
+
326
+ class BertOnlyNSPHead(nn.Module):
327
+ def __init__(self, config):
328
+ super(BertOnlyNSPHead, self).__init__()
329
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
330
+
331
+ def forward(self, pooled_output):
332
+ seq_relationship_score = self.seq_relationship(pooled_output)
333
+ return seq_relationship_score
334
+
335
+
336
+ class BertPreTrainingHeads(nn.Module):
337
+ def __init__(self, config):
338
+ super(BertPreTrainingHeads, self).__init__()
339
+ self.predictions = BertLMPredictionHead(config)
340
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
341
+
342
+ def forward(self, sequence_output, pooled_output):
343
+ prediction_scores = self.predictions(sequence_output)
344
+ seq_relationship_score = self.seq_relationship(pooled_output)
345
+ return prediction_scores, seq_relationship_score
346
+
347
+
348
+ class BertPreTrainedModel(nn.Module):
349
+ config_class = BertConfig
350
+ base_model_prefix = "bert"
351
+
352
+ def __init__(self, config):
353
+ super(BertPreTrainedModel, self).__init__()
354
+ self.config = config
355
+
356
+ def _init_weights(self, module):
357
+ """ Initialize the weights """
358
+ if isinstance(module, (nn.Linear, nn.Embedding)):
359
+ # Slightly different from the TF version which uses truncated_normal for initialization
360
+ # cf https://github.com/pytorch/pytorch/pull/5617
361
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
362
+ elif isinstance(module, BertLayerNorm):
363
+ module.bias.data.zero_()
364
+ module.weight.data.fill_(1.0)
365
+ if isinstance(module, nn.Linear) and module.bias is not None:
366
+ module.bias.data.zero_()
367
+
368
+
369
+ class BertModel(BertPreTrainedModel):
370
+ r"""
371
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
372
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
373
+ Sequence of hidden-states at the output of the last layer of the model.
374
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
375
+ Last layer hidden-state of the first token of the sequence (classification token)
376
+ further processed by a Linear layer and a Tanh activation function. The Linear
377
+ layer weights are trained from the next sentence prediction (classification)
378
+ objective during Bert pretraining. This output is usually *not* a good summary
379
+ of the semantic content of the input, you're often better with averaging or pooling
380
+ the sequence of hidden-states for the whole input sequence.
381
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
382
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
383
+ of shape ``(batch_size, sequence_length, hidden_size)``:
384
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
385
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
386
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
387
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
388
+
389
+ Examples::
390
+
391
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
392
+ model = BertModel.from_pretrained('bert-base-uncased')
393
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
394
+ outputs = model(input_ids)
395
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
396
+
397
+ """
398
+ def __init__(self, config):
399
+ super(BertModel, self).__init__(config)
400
+
401
+ self.embeddings = BertEmbeddings(config)
402
+ self.encoder = BertEncoder(config)
403
+ # self.pooler = BertPooler(config)
404
+
405
+ self.apply(self._init_weights)
406
+
407
+ @torch.jit.ignore
408
+ def set_grad_checkpointing(self, enable=True):
409
+ if enable:
410
+ assert not self.config.output_attentions, \
411
+ "Grad checkpointing is currently conflict with output_attentions for BertEncoder, \
412
+ please set it to False in BertConfig"
413
+ self.encoder.grad_checkpointing = enable
414
+
415
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
416
+ if attention_mask is None:
417
+ attention_mask = torch.ones_like(input_ids)
418
+ if token_type_ids is None:
419
+ token_type_ids = torch.zeros_like(input_ids)
420
+
421
+ # We create a 3D attention mask from a 2D tensor mask.
422
+ # Sizes are [batch_size, 1, 1, to_seq_length]
423
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
424
+ # this attention mask is simpler than the triangular masking of causal attention
425
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
426
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
427
+
428
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
429
+ # masked positions, this operation will create a tensor which is 0.0 for
430
+ # positions we want to attend and -10000.0 for masked positions.
431
+ # Since we are adding it to the raw scores before the softmax, this is
432
+ # effectively the same as removing these entirely.
433
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
434
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
435
+
436
+ # Prepare head mask if needed
437
+ # 1.0 in head_mask indicate we keep the head
438
+ # attention_probs has shape bsz x n_heads x N x N
439
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
440
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
441
+ if head_mask is not None:
442
+ if head_mask.dim() == 1:
443
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
444
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
445
+ elif head_mask.dim() == 2:
446
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
447
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
448
+ else:
449
+ head_mask = [None] * self.config.num_hidden_layers
450
+
451
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
452
+ encoder_outputs = self.encoder(embedding_output,
453
+ extended_attention_mask,
454
+ head_mask=head_mask)
455
+ sequence_output = encoder_outputs[0]
456
+ # pooled_output = self.pooler(sequence_output)
457
+ pooled_output = None
458
+
459
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
460
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
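
A toy-sized sketch of calling this BertModel with a padding mask, mirroring how CLIP.encode_text in model.py masks [PAD] positions and takes the [CLS] hidden state. The hyperparameters and token ids below are illustrative (not the repo's RoBERTa-wwm configs), and the fields not passed rely on BertConfig's defaults:

import torch
from cn_clip.clip.configuration_bert import BertConfig
from cn_clip.clip.modeling_bert import BertModel

config = BertConfig(
    vocab_size_or_config_json_file=21128,  # Chinese BERT vocab size, as in the configs above
    hidden_size=128, num_hidden_layers=2, num_attention_heads=2, intermediate_size=256,
)
model = BertModel(config).eval()

input_ids = torch.tensor([[101, 2769, 4263, 102, 0, 0]])  # [CLS] ... [SEP] plus [PAD] ids (illustrative)
attn_mask = input_ids.ne(0).long()                        # 1 for real tokens, 0 for padding
sequence_output = model(input_ids, attention_mask=attn_mask)[0]  # [1, 6, 128]
cls_state = sequence_output[:, 0, :]                      # the per-sequence vector encode_text projects
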
Model/CLIP/cn_clip/clip/utils.py ADDED
@@ -0,0 +1,196 @@
1
+ # Code modified from https://github.com/openai/CLIP
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Union, List
7
+ import urllib
8
+
9
+ import torch
10
+ from torchvision.transforms import Compose, ToTensor, Normalize, Resize
11
+ from tqdm import tqdm
12
+
13
+ from cn_clip.clip import _tokenizer
14
+ from cn_clip.clip.model import convert_weights, CLIP, restore_model
15
+
16
+ __all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]
17
+
18
+ _MODELS = {
19
+ "ViT-B-16": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt",
20
+ "ViT-L-14": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14.pt",
21
+ "ViT-L-14-336": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14-336.pt",
22
+ "ViT-H-14": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-h-14.pt",
23
+ "RN50": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_rn50.pt",
24
+ }
25
+ _MODEL_INFO = {
26
+ "ViT-B-16": {
27
+ "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
28
+ "input_resolution": 224
29
+ },
30
+ "ViT-L-14": {
31
+ "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
32
+ "input_resolution": 224
33
+ },
34
+ "ViT-L-14-336": {
35
+ "struct": "ViT-L-14-336@RoBERTa-wwm-ext-base-chinese",
36
+ "input_resolution": 336
37
+ },
38
+ "ViT-H-14": {
39
+ "struct": "ViT-H-14@RoBERTa-wwm-ext-large-chinese",
40
+ "input_resolution": 224
41
+ },
42
+ "RN50": {
43
+ "struct": "RN50@RBT3-chinese",
44
+ "input_resolution": 224
45
+ },
46
+ }
47
+
48
+
49
+ def _download(url: str, root: str):
50
+ os.makedirs(root, exist_ok=True)
51
+ filename = os.path.basename(url)
52
+
53
+ download_target = os.path.join(root, filename)
54
+
55
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
56
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
57
+
58
+ if os.path.isfile(download_target):
59
+ return download_target
60
+
61
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
62
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
63
+ unit_divisor=1024) as loop:
64
+ while True:
65
+ buffer = source.read(8192)
66
+ if not buffer:
67
+ break
68
+
69
+ output.write(buffer)
70
+ loop.update(len(buffer))
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def available_models() -> List[str]:
80
+ """Returns the names of available CLIP models"""
81
+ return list(_MODELS.keys())
82
+
83
+
84
+ def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
85
+ download_root: str = None, resume: str = None):
86
+ if resume is not None:
87
+ model_path = resume
88
+ elif name in _MODELS:
89
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
90
+ elif os.path.isfile(name):
91
+ model_path = name
92
+ else:
93
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
94
+
95
+ with open(model_path, 'rb') as opened_file:
96
+ # loading saved checkpoint
97
+ checkpoint = torch.load(opened_file, map_location="cpu")
98
+
99
+ model = create_model(_MODEL_INFO[name]['struct'], checkpoint)
100
+ if str(device) == "cpu":
101
+ model.float()
102
+ else:
103
+ model.to(device)
104
+
105
+ return model, image_transform(_MODEL_INFO[name]['input_resolution'])
106
+
107
+
108
+ def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
109
+ bert_path=None):
110
+ """Load CLIP and BERT model weights
111
+ """
112
+
113
+ bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
114
+ clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None
115
+
116
+ restore_model(model, clip_state_dict, bert_state_dict).to(device)
117
+
118
+ if str(device) == "cpu":
119
+ model.float()
120
+ return model
121
+
122
+
123
+ def tokenize(texts: Union[str, List[str]], context_length: int = 64) -> torch.LongTensor:
124
+ """
125
+ Returns the tokenized representation of given input string(s)
126
+ Parameters
127
+ ----------
128
+ texts : Union[str, List[str]]
129
+ An input string or a list of input strings to tokenize
130
+ context_length : int
131
+ The context length to use; the default here is 64, and texts longer than context_length - 2 tokens are truncated
132
+ Returns
133
+ -------
134
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
135
+ """
136
+ if isinstance(texts, str):
137
+ texts = [texts]
138
+
139
+ all_tokens = []
140
+ for text in texts:
141
+ all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
142
+ :context_length - 2] + [_tokenizer.vocab['[SEP]']])
143
+
144
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
145
+
146
+ for i, tokens in enumerate(all_tokens):
147
+ assert len(tokens) <= context_length
148
+ result[i, :len(tokens)] = torch.tensor(tokens)
149
+
150
+ return result
151
+
152
+
153
+ def _convert_to_rgb(image):
154
+ return image.convert('RGB')
155
+
156
+
157
+ def image_transform(image_size=224):
158
+ transform = Compose([
159
+ _convert_to_rgb,
160
+ Resize((image_size, image_size)),
161
+ ToTensor(),
162
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
163
+ ])
164
+ return transform
165
+
166
+
167
+ def create_model(model_name, checkpoint=None):
168
+ vision_model, text_model = model_name.split('@')
169
+ # Initialize the model.
170
+ vision_model_config_file = Path(
171
+ __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
172
+ # print('Loading vision model config from', vision_model_config_file)
173
+ assert os.path.exists(vision_model_config_file)
174
+
175
+ text_model_config_file = Path(
176
+ __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
177
+ # print('Loading text model config from', text_model_config_file)
178
+ assert os.path.exists(text_model_config_file)
179
+
180
+ with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
181
+ model_info = json.load(fv)
182
+ for k, v in json.load(ft).items():
183
+ model_info[k] = v
184
+ if isinstance(model_info['vision_layers'], str):
185
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
186
+ # print('Model info', model_info)
187
+ model = CLIP(**model_info)
188
+ convert_weights(model)
189
+ if checkpoint:
190
+ sd = checkpoint["state_dict"]
191
+ if next(iter(sd.items()))[0].startswith('module'):
192
+ sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
193
+ model.load_state_dict(sd)
194
+ return model
195
+
196
+
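
Taken together, the helpers in this file give the usual inference loop: restore a checkpoint with load_from_name, preprocess an image with the returned transform, tokenize the candidate captions, and score them with get_similarity. A minimal sketch; the image path is a placeholder and the checkpoint is fetched to ~/.cache/clip on first use:

import torch
from PIL import Image
from cn_clip.clip.utils import load_from_name, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device)
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image file
texts = tokenize(["一只猫", "一只狗"]).to(device)                        # "a cat", "a dog"

with torch.no_grad():
    logits_per_image, logits_per_text = model.get_similarity(image, texts)
    probs = logits_per_image.softmax(dim=-1)  # probabilities over the two captions
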
Model/CLIP/cn_clip/clip/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
Model/CLIP/cn_clip/eval/__init__.py ADDED
File without changes
Model/CLIP/cn_clip/eval/data.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import logging
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+
7
+ from PIL import Image
8
+ import base64
9
+ from io import BytesIO
10
+
11
+ import lmdb
12
+
13
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
14
+
15
+ import torch
16
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.utils.data.sampler import SequentialSampler
19
+
20
+ import torchvision.datasets as datasets
21
+
22
+ from cn_clip.clip import tokenize
23
+
24
+ def _convert_to_rgb(image):
25
+ return image.convert('RGB')
26
+
27
+
28
+ def _preprocess_text(text):
29
+ # adapt the text to Chinese BERT vocab
30
+ text = text.lower().replace("“", "\"").replace("”", "\"")
31
+ return text
32
+
33
+ class EvalTxtDataset(Dataset):
34
+ def __init__(self, jsonl_filename, max_txt_length=24):
35
+ assert os.path.exists(jsonl_filename), "The annotation datafile {} does not exist!".format(jsonl_filename)
36
+
37
+ logging.debug(f'Loading jsonl data from {jsonl_filename}.')
38
+ self.texts = []
39
+ with open(jsonl_filename, "r") as fin:
40
+ for line in fin:
41
+ obj = json.loads(line.strip())
42
+ text_id = obj['text_id']
43
+ text = obj['text']
44
+ self.texts.append((text_id, text))
45
+ logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')
46
+
47
+ self.max_txt_length = max_txt_length
48
+
49
+ def __len__(self):
50
+ return len(self.texts)
51
+
52
+ def __getitem__(self, idx):
53
+ text_id, text = self.texts[idx]
54
+ text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
55
+ return text_id, text
56
+
57
+ class EvalImgDataset(Dataset):
58
+ def __init__(self, lmdb_imgs, resolution=224):
59
+ assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} does not exist!".format(lmdb_imgs)
60
+
61
+ logging.debug(f'Loading image LMDB from {lmdb_imgs}.')
62
+
63
+ self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
64
+ self.txn_imgs = self.env_imgs.begin(buffers=True)
65
+ self.cursor_imgs = self.txn_imgs.cursor()
66
+ self.iter_imgs = iter(self.cursor_imgs)
67
+ self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
68
+ logging.info("The specified LMDB directory contains {} images.".format(self.number_images))
69
+
70
+ self.transform = self._build_transform(resolution)
71
+
72
+ def _build_transform(self, resolution):
73
+ normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
74
+ return Compose([
75
+ Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
76
+ _convert_to_rgb,
77
+ ToTensor(),
78
+ normalize,
79
+ ])
80
+
81
+ def __len__(self):
82
+ return self.number_images
83
+
84
+ def __getitem__(self, idx):
85
+ img_id, image_b64 = next(self.iter_imgs)
86
+ if img_id == b"num_images":
87
+ img_id, image_b64 = next(self.iter_imgs)
88
+
89
+ img_id = img_id.tobytes()
90
+ image_b64 = image_b64.tobytes()
91
+
92
+ img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
93
+ image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
94
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64))) # already resized
95
+ image = self.transform(image)
96
+
97
+ return img_id, image
98
+
99
+ @dataclass
100
+ class DataInfo:
101
+ dataloader: DataLoader
102
+ sampler: DistributedSampler
103
+
104
+ def get_eval_txt_dataset(args, max_txt_length=24):
105
+ input_filename = args.text_data
106
+ dataset = EvalTxtDataset(
107
+ input_filename,
108
+ max_txt_length=max_txt_length)
109
+ num_samples = len(dataset)
110
+ sampler = SequentialSampler(dataset)
111
+
112
+ dataloader = DataLoader(
113
+ dataset,
114
+ batch_size=args.text_batch_size,
115
+ num_workers=0,
116
+ pin_memory=True,
117
+ sampler=sampler,
118
+ drop_last=False,
119
+ )
120
+ dataloader.num_samples = num_samples
121
+ dataloader.num_batches = len(dataloader)
122
+
123
+ return DataInfo(dataloader, sampler)
124
+
125
+ def fetch_resolution(vision_model):
126
+ # fetch the resolution from the vision model config
127
+ vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
128
+ with open(vision_model_config_file, 'r') as fv:
129
+ model_info = json.load(fv)
130
+ return model_info["image_resolution"]
131
+
132
+ def get_eval_img_dataset(args):
133
+ lmdb_imgs = args.image_data
134
+ dataset = EvalImgDataset(
135
+ lmdb_imgs, resolution=fetch_resolution(args.vision_model))
136
+ num_samples = len(dataset)
137
+ sampler = SequentialSampler(dataset)
138
+
139
+ dataloader = DataLoader(
140
+ dataset,
141
+ batch_size=args.img_batch_size,
142
+ num_workers=0,
143
+ pin_memory=True,
144
+ sampler=sampler,
145
+ drop_last=False,
146
+ )
147
+ dataloader.num_samples = num_samples
148
+ dataloader.num_batches = len(dataloader)
149
+
150
+ return DataInfo(dataloader, sampler)
151
+
152
+ def get_imagenet_dataset(args, preprocess_fn, split):
153
+ assert split in ["val"]
154
+
155
+ data_path = args.imagenet_val
156
+ assert data_path
157
+
158
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
159
+
160
+ dataloader = torch.utils.data.DataLoader(
161
+ dataset,
162
+ batch_size=args.img_batch_size,
163
+ num_workers=args.num_workers,
164
+ sampler=None,
165
+ )
166
+
167
+ return DataInfo(dataloader, None)
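
The text side of this pipeline expects a jsonl file with one {"text_id": ..., "text": ...} object per line, and get_eval_txt_dataset only reads args.text_data and args.text_batch_size. A small sketch of preparing such a file and iterating the loader; the paths, batch size and example queries are placeholders:

import json
from argparse import Namespace

from cn_clip.eval.data import get_eval_txt_dataset

with open("queries.jsonl", "w") as f:
    for text_id, text in [(1, "红色连衣裙"), (2, "蓝色牛仔裤")]:  # "red dress", "blue jeans"
        f.write(json.dumps({"text_id": text_id, "text": text}, ensure_ascii=False) + "\n")

args = Namespace(text_data="queries.jsonl", text_batch_size=2)
data = get_eval_txt_dataset(args, max_txt_length=24)
for text_ids, tokens in data.dataloader:
    print(text_ids, tokens.shape)  # tokens: [batch, 24] LongTensor of [CLS] ... [SEP] [PAD] ids
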
Model/CLIP/cn_clip/eval/evaluation.py ADDED
@@ -0,0 +1,157 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script computes the recall scores given the ground-truth annotations and predictions.
4
+ '''
5
+
6
+ import json
7
+ import sys
8
+ import os
9
+ import string
10
+ import numpy as np
11
+ import time
12
+
13
+ NUM_K = 10
14
+
15
+ def read_submission(submit_path, reference, k=5):
16
+ # check whether the path of submitted file exists
17
+ if not os.path.exists(submit_path):
18
+ raise Exception("The submission file is not found!")
19
+
20
+ submission_dict = {}
21
+ ref_qids = set(reference.keys())
22
+
23
+ with open(submit_path) as fin:
24
+ for line in fin:
25
+ line = line.strip()
26
+ try:
27
+ pred_obj = json.loads(line)
28
+ except:
29
+ raise Exception('Cannot parse this line into json object: {}'.format(line))
30
+ if "text_id" not in pred_obj:
31
+ raise Exception('There exists one line not containing text_id: {}'.format(line))
32
+ if not isinstance(pred_obj['text_id'], int):
33
+ raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['text_id']))
34
+ qid = pred_obj["text_id"]
35
+ if "image_ids" not in pred_obj:
36
+ raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
37
+ image_ids = pred_obj["image_ids"]
38
+ if not isinstance(image_ids, list):
39
+ raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
40
+ # check whether there are exactly K predicted image_ids for each text
41
+ if len(image_ids) != k:
42
+ raise Exception('Text_id {} has wrong number of predicted image_ids! Required {}, but {} found.'.format(qid, k, len(image_ids)))
43
+ # check whether any prediction for this text is invalid
44
+ for rank, image_id in enumerate(image_ids):
45
+ if not isinstance(image_id, int):
46
+ raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
47
+ # check whether there are duplicate predicted image_ids for a single text
48
+ if len(set(image_ids)) != k:
49
+ raise Exception('Text_id {} has duplicate image_ids in your prediction. Please check again!'.format(qid))
50
+ submission_dict[qid] = image_ids # here we save the list of predicted image ids
51
+
52
+ # check if any text is missing in the submission
53
+ pred_qids = set(submission_dict.keys())
54
+ nopred_qids = ref_qids - pred_qids
55
+ if len(nopred_qids) != 0:
56
+ raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))
57
+
58
+ return submission_dict
59
+
60
+
61
+ def dump_2_json(info, path):
62
+ with open(path, 'w') as output_json_file:
63
+ json.dump(info, output_json_file)
64
+
65
+
66
+ def report_error_msg(detail, showMsg, out_p):
67
+ error_dict=dict()
68
+ error_dict['errorDetail']=detail
69
+ error_dict['errorMsg']=showMsg
70
+ error_dict['score']=0
71
+ error_dict['scoreJson']={}
72
+ error_dict['success']=False
73
+ dump_2_json(error_dict,out_p)
74
+
75
+
76
+ def report_score(r1, r5, r10, out_p):
77
+ result = dict()
78
+ result['success']=True
79
+ mean_recall = (r1 + r5 + r10) / 3.0
80
+ result['score'] = mean_recall * 100
81
+ result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
82
+ dump_2_json(result,out_p)
83
+
84
+
85
+ def read_reference(path):
86
+ fin = open(path)
87
+ reference = dict()
88
+ for line in fin:
89
+ line = line.strip()
90
+ obj = json.loads(line)
91
+ reference[obj['text_id']] = obj['image_ids']
92
+ return reference
93
+
94
+ def compute_score(golden_file, predict_file):
95
+ # read ground-truth
96
+ reference = read_reference(golden_file)
97
+
98
+ # read predictions
99
+ k = 10
100
+ predictions = read_submission(predict_file, reference, k)
101
+
102
+ # compute score for each text
103
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
104
+ for qid in reference.keys():
105
+ ground_truth_ids = set(reference[qid])
106
+ top10_pred_ids = predictions[qid]
107
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
108
+ r1_stat += 1
109
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
110
+ r5_stat += 1
111
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
112
+ r10_stat += 1
113
+ # the higher score, the better
114
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
115
+ mean_recall = (r1 + r5 + r10) / 3.0
116
+ result = [mean_recall, r1, r5, r10]
117
+ result = [score * 100 for score in result]
118
+ return result
119
+
120
+
121
+ if __name__=="__main__":
122
+ # the path of answer json file (eg. test_queries_answers.jsonl)
123
+ standard_path = sys.argv[1]
124
+ # the path of prediction file (eg. example_pred.jsonl)
125
+ submit_path = sys.argv[2]
126
+ # the score will be dumped into this output json file
127
+ out_path = sys.argv[3]
128
+
129
+ print("Read standard from %s" % standard_path)
130
+ print("Read user submit file from %s" % submit_path)
131
+
132
+ try:
133
+ # read ground-truth
134
+ reference = read_reference(standard_path)
135
+
136
+ # read predictions
137
+ k = 10
138
+ predictions = read_submission(submit_path, reference, k)
139
+
140
+ # compute score for each text
141
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
142
+ for qid in reference.keys():
143
+ ground_truth_ids = set(reference[qid])
144
+ top10_pred_ids = predictions[qid]
145
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
146
+ r1_stat += 1
147
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
148
+ r5_stat += 1
149
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
150
+ r10_stat += 1
151
+ # the higher score, the better
152
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
153
+ report_score(r1, r5, r10, out_path)
154
+ print("The evaluation finished successfully.")
155
+ except Exception as e:
156
+ report_error_msg(e.args[0], e.args[0], out_path)
157
+ print("The evaluation failed: {}".format(e.args[0]))
Model/CLIP/cn_clip/eval/evaluation_tr.py ADDED
@@ -0,0 +1,157 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script computes the recall scores given the ground-truth annotations and predictions.
4
+ '''
5
+
6
+ import json
7
+ import sys
8
+ import os
9
+ import string
10
+ import numpy as np
11
+ import time
12
+
13
+ NUM_K = 10
14
+
15
+ def read_submission(submit_path, reference, k=5):
16
+ # check whether the path of submitted file exists
17
+ if not os.path.exists(submit_path):
18
+ raise Exception("The submission file is not found!")
19
+
20
+ submission_dict = {}
21
+ ref_image_ids = set(reference.keys())
22
+
23
+ with open(submit_path) as fin:
24
+ for line in fin:
25
+ line = line.strip()
26
+ try:
27
+ pred_obj = json.loads(line)
28
+ except:
29
+ raise Exception('Cannot parse this line into json object: {}'.format(line))
30
+ if "image_id" not in pred_obj:
31
+ raise Exception('There exists one line not containing image_id: {}'.format(line))
32
+ if not isinstance(pred_obj['image_id'], int):
33
+ raise Exception('Found an invalid image_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['image_id']))
34
+ image_id = pred_obj['image_id']
35
+ if "text_ids" not in pred_obj:
36
+ raise Exception('There exists one line not containing the predicted text_ids: {}'.format(line))
37
+ text_ids = pred_obj["text_ids"]
38
+ if not isinstance(text_ids, list):
39
+ raise Exception('The text_ids field of image_id {} is not a list, please check your schema'.format(image_id))
40
+ # check whether there are exactly K predicted text_ids for each image
41
+ if len(text_ids) != k:
42
+ raise Exception('Image_id {} has wrong number of predicted text_ids! Required {}, but {} found.'.format(image_id, k, len(text_ids)))
43
+ # check whether any prediction for this image is invalid
44
+ for rank, text_id in enumerate(text_ids):
45
+ if not isinstance(text_id, int):
46
+ raise Exception('Image_id {} has an invalid predicted text_id {} at rank {}, it should be an integer (not string), please check your schema'.format(image_id, text_id, rank + 1))
47
+ # check whether there are duplicate predicted text_ids for a single image
48
+ if len(set(text_ids)) != k:
49
+ raise Exception('Image_id {} has duplicate text_ids in your prediction. Please check again!'.format(image_id))
50
+ submission_dict[image_id] = text_ids # here we save the list of product ids
51
+
52
+ # check if any image is missing in the submission
53
+ pred_image_ids = set(submission_dict.keys())
54
+ nopred_image_ids = ref_image_ids - pred_image_ids
55
+ if len(nopred_image_ids) != 0:
56
+ raise Exception('The following image_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_image_ids])))
57
+
58
+ return submission_dict
59
+
60
+
61
+ def dump_2_json(info, path):
62
+ with open(path, 'w') as output_json_file:
63
+ json.dump(info, output_json_file)
64
+
65
+
66
+ def report_error_msg(detail, showMsg, out_p):
67
+ error_dict=dict()
68
+ error_dict['errorDetail']=detail
69
+ error_dict['errorMsg']=showMsg
70
+ error_dict['score']=0
71
+ error_dict['scoreJson']={}
72
+ error_dict['success']=False
73
+ dump_2_json(error_dict,out_p)
74
+
75
+
76
+ def report_score(r1, r5, r10, out_p):
77
+ result = dict()
78
+ result['success']=True
79
+ mean_recall = (r1 + r5 + r10) / 3.0
80
+ result['score'] = mean_recall * 100
81
+ result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
82
+ dump_2_json(result,out_p)
83
+
84
+
85
+ def read_reference(path):
86
+ fin = open(path)
87
+ reference = dict()
88
+ for line in fin:
89
+ line = line.strip()
90
+ obj = json.loads(line)
91
+ reference[obj['image_id']] = obj['text_ids']
92
+ return reference
93
+
94
+ def compute_score(golden_file, predict_file):
95
+ # read ground-truth
96
+ reference = read_reference(golden_file)
97
+
98
+ # read predictions
99
+ k = 10
100
+ predictions = read_submission(predict_file, reference, k)
101
+
102
+ # compute score for each text
103
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
104
+ for qid in reference.keys():
105
+ ground_truth_ids = set(reference[qid])
106
+ top10_pred_ids = predictions[qid]
107
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
108
+ r1_stat += 1
109
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
110
+ r5_stat += 1
111
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
112
+ r10_stat += 1
113
+ # the higher the score, the better
114
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
115
+ mean_recall = (r1 + r5 + r10) / 3.0
116
+ result = [mean_recall, r1, r5, r10]
117
+ result = [score * 100 for score in result]
118
+ return result
119
+
120
+
121
+ if __name__=="__main__":
122
+ # the path of answer json file (eg. test_queries_answers.jsonl)
123
+ standard_path = sys.argv[1]
124
+ # the path of prediction file (eg. example_pred.jsonl)
125
+ submit_path = sys.argv[2]
126
+ # the score will be dumped into this output json file
127
+ out_path = sys.argv[3]
128
+
129
+ print("Read standard from %s" % standard_path)
130
+ print("Read user submit file from %s" % submit_path)
131
+
132
+ try:
133
+ # read ground-truth
134
+ reference = read_reference(standard_path)
135
+
136
+ # read predictions
137
+ k = 10
138
+ predictions = read_submission(submit_path, reference, k)
139
+
140
+ # compute score for each image
141
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
142
+ for qid in reference.keys():
143
+ ground_truth_ids = set(reference[qid])
144
+ top10_pred_ids = predictions[qid]
145
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
146
+ r1_stat += 1
147
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
148
+ r5_stat += 1
149
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
150
+ r10_stat += 1
151
+ # the higher the score, the better
152
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
153
+ report_score(r1, r5, r10, out_path)
154
+ print("The evaluation finished successfully.")
155
+ except Exception as e:
156
+ report_error_msg(e.args[0], e.args[0], out_path)
157
+ print("The evaluation failed: {}".format(e.args[0]))
Model/CLIP/cn_clip/eval/extract_features.py ADDED
@@ -0,0 +1,205 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script extracts image and text features for evaluation (single-GPU).
4
+ '''
5
+
6
+ import os
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+ import json
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ from cn_clip.clip.model import convert_weights, CLIP
16
+ from cn_clip.training.main import convert_models_to_fp32
17
+ from cn_clip.eval.data import get_eval_img_dataset, get_eval_txt_dataset
18
+
19
+ def parse_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument(
22
+ '--extract-image-feats',
23
+ action="store_true",
24
+ default=False,
25
+ help="Whether to extract image features."
26
+ )
27
+ parser.add_argument(
28
+ '--extract-text-feats',
29
+ action="store_true",
30
+ default=False,
31
+ help="Whether to extract text features."
32
+ )
33
+ parser.add_argument(
34
+ '--image-data',
35
+ type=str,
36
+ default="../Multimodal_Retrieval/lmdb/test/imgs",
37
+ help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
38
+ )
39
+ parser.add_argument(
40
+ '--text-data',
41
+ type=str,
42
+ default="../Multimodal_Retrieval/test_texts.jsonl",
43
+ help="If --extract-text-feats is True, specify the path of input text Jsonl file."
44
+ )
45
+ parser.add_argument(
46
+ '--image-feat-output-path',
47
+ type=str,
48
+ default=None,
49
+ help="If --extract-image-feats is True, specify the path of output image features."
50
+ )
51
+ parser.add_argument(
52
+ '--text-feat-output-path',
53
+ type=str,
54
+ default=None,
55
+ help="If --extract-image-feats is True, specify the path of output text features."
56
+ )
57
+ parser.add_argument(
58
+ "--img-batch-size", type=int, default=64, help="Image batch size."
59
+ )
60
+ parser.add_argument(
61
+ "--text-batch-size", type=int, default=64, help="Text batch size."
62
+ )
63
+ parser.add_argument(
64
+ "--context-length", type=int, default=64, help="The maximum length of input text (include [CLS] & [SEP] tokens)."
65
+ )
66
+ parser.add_argument(
67
+ "--resume",
68
+ default=None,
69
+ type=str,
70
+ help="path to latest checkpoint (default: none)",
71
+ )
72
+ parser.add_argument(
73
+ "--precision",
74
+ choices=["amp", "fp16", "fp32"],
75
+ default="amp",
76
+ help="Floating point precition."
77
+ )
78
+ parser.add_argument(
79
+ "--vision-model",
80
+ choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
81
+ default="ViT-B-16",
82
+ help="Name of the vision backbone to use.",
83
+ )
84
+ parser.add_argument(
85
+ "--text-model",
86
+ choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
87
+ default="RoBERTa-wwm-ext-base-chinese",
88
+ help="Name of the text backbone to use.",
89
+ )
90
+ parser.add_argument(
91
+ "--debug",
92
+ default=False,
93
+ action="store_true",
94
+ help="If true, more information is logged."
95
+ )
96
+ args = parser.parse_args()
97
+
98
+ return args
99
+
100
+
101
+ if __name__ == "__main__":
102
+ args = parse_args()
103
+
104
+ assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!"
105
+
106
+ # Log params.
107
+ print("Params:")
108
+ for name in sorted(vars(args)):
109
+ val = getattr(args, name)
110
+ print(f" {name}: {val}")
111
+
112
+ args.gpu = 0
113
+ torch.cuda.set_device(args.gpu)
114
+
115
+ # Initialize the model.
116
+ vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
117
+ print('Loading vision model config from', vision_model_config_file)
118
+ assert os.path.exists(vision_model_config_file)
119
+
120
+ text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
121
+ print('Loading text model config from', text_model_config_file)
122
+ assert os.path.exists(text_model_config_file)
123
+
124
+ with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
125
+ model_info = json.load(fv)
126
+ if isinstance(model_info['vision_layers'], str):
127
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
128
+ for k, v in json.load(ft).items():
129
+ model_info[k] = v
130
+
131
+ model = CLIP(**model_info)
132
+ convert_weights(model)
133
+
134
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
135
+ if args.precision == "amp" or args.precision == "fp32":
136
+ convert_models_to_fp32(model)
137
+ model.cuda(args.gpu)
138
+ if args.precision == "fp16":
139
+ convert_weights(model)
140
+
141
+ # Get data.
142
+ if args.extract_image_feats:
143
+ print("Preparing image inference dataset.")
144
+ img_data = get_eval_img_dataset(args)
145
+ if args.extract_text_feats:
146
+ print("Preparing text inference dataset.")
147
+ text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)
148
+
149
+ # Resume from a checkpoint.
150
+ print("Begin to load model checkpoint from {}.".format(args.resume))
151
+ assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
152
+ # Map model to be loaded to specified single gpu.
153
+ loc = "cuda:{}".format(args.gpu)
154
+ checkpoint = torch.load(args.resume, map_location='cpu')
155
+ start_epoch = checkpoint["epoch"]
156
+ sd = checkpoint["state_dict"]
157
+ if next(iter(sd.items()))[0].startswith('module'):
158
+ sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
159
+ model.load_state_dict(sd)
160
+ print(
161
+ f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
162
+ )
163
+
164
+ # Make inference for texts
165
+ if args.extract_text_feats:
166
+ print('Make inference for texts...')
167
+ if args.text_feat_output_path is None:
168
+ args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
169
+ write_cnt = 0
170
+ with open(args.text_feat_output_path, "w") as fout:
171
+ model.eval()
172
+ dataloader = text_data.dataloader
173
+ with torch.no_grad():
174
+ for batch in tqdm(dataloader):
175
+ text_ids, texts = batch
176
+ texts = texts.cuda(args.gpu, non_blocking=True)
177
+ text_features = model(None, texts)
178
+ text_features /= text_features.norm(dim=-1, keepdim=True)
179
+ for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
180
+ fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
181
+ write_cnt += 1
182
+ print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))
183
+
184
+ # Make inference for images
185
+ if args.extract_image_feats:
186
+ print('Make inference for images...')
187
+ if args.image_feat_output_path is None:
188
+ # by default, we store the image features under the same directory with the text features
189
+ args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
190
+ write_cnt = 0
191
+ with open(args.image_feat_output_path, "w") as fout:
192
+ model.eval()
193
+ dataloader = img_data.dataloader
194
+ with torch.no_grad():
195
+ for batch in tqdm(dataloader):
196
+ image_ids, images = batch
197
+ images = images.cuda(args.gpu, non_blocking=True)
198
+ image_features = model(images, None)
199
+ image_features /= image_features.norm(dim=-1, keepdim=True)
200
+ for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
201
+ fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
202
+ write_cnt += 1
203
+ print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))
204
+
205
+ print("Done!")
Model/CLIP/cn_clip/eval/imagenet_zeroshot_templates.py ADDED
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script records the imagenet classnames and templates (both translated in Chinese)
4
+ used for zero-shot evaluation.
5
+
6
+ The original classnames and templates in English are derived from open_clip
7
+ (https://github.com/mlfoundations/open_clip/blob/main/src/training/imagenet_zeroshot_data.py)
8
+ The translated classnames and templates in Chinese are derived from wukong
9
+ (https://gitee.com/mindspore/models/tree/master/research/mm/wukong)
10
+ '''
11
+
12
+ imagenet_classnames = [ "丁鲷", "金鱼", "大白鲨", "虎鲨", "锤头鲨", "电鳐", "黄貂鱼", "公鸡", "母鸡", "鸵鸟",
13
+ "燕雀", "金翅雀", "家朱雀", "灯芯草雀", "靛蓝雀", "蓝鹀", "夜莺", "松鸦", "喜鹊", "山雀",
14
+ "河鸟", "鸢(猛禽)", "秃头鹰", "秃鹫", "大灰猫头鹰", "欧洲火蝾螈", "普通蝾螈", "水蜥", "斑点蝾螈", "蝾螈",
15
+ "牛蛙", "树蛙", "尾蛙", "红海龟", "皮革龟", "泥龟", "淡水龟", "箱龟", "带状壁虎", "普通鬣蜥",
16
+ "美国变色龙", "鞭尾蜥蜴", "飞龙科蜥蜴", "褶边蜥蜴", "鳄鱼蜥蜴", "毒蜥", "绿蜥蜴", "非洲变色龙", "科莫多蜥蜴", "非洲鳄",
17
+ "美国鳄鱼", "三角龙", "雷蛇", "环蛇", "希腊蛇", "绿蛇", "国王蛇", "袜带蛇", "水蛇", "藤蛇",
18
+ "夜蛇", "大蟒蛇", "岩石蟒蛇", "印度眼镜蛇", "绿曼巴", "海蛇", "角腹蛇", "菱纹响尾蛇", "角响尾蛇", "三叶虫",
19
+ "盲蜘蛛", "蝎子", "黑金花园蜘蛛", "谷仓蜘蛛", "花园蜘蛛", "黑寡妇蜘蛛", "狼蛛", "狼蜘蛛", "壁虱", "蜈蚣",
20
+ "黑松鸡", "松鸡", "披肩鸡", "草原鸡", "孔雀", "鹌鹑", "鹧鸪", "非洲灰鹦鹉", "金刚鹦鹉", "硫冠鹦鹉",
21
+ "短尾鹦鹉", "褐翅鸦鹃", "食蜂鸟;蜂虎", "犀鸟", "蜂鸟", "鹟䴕", "巨嘴鸟;大嘴鸟", "野鸭", "红胸秋沙鸭", "鹅",
22
+ "黑天鹅", "大象", "针鼹鼠", "鸭嘴兽", "沙袋鼠", "考拉", "袋熊", "水母", "海葵", "脑珊瑚",
23
+ "扁形虫扁虫", "线虫", "海螺", "蜗牛", "鼻涕虫", "海蛞蝓;海参", "石鳖", "鹦鹉螺", "珍宝蟹", "石蟹",
24
+ "招潮蟹", "帝王蟹", "美国龙虾", "大螯虾", "小龙虾", "寄居蟹", "等足目动物(明虾和螃蟹近亲)", "白鹳", "黑鹳", "鹭",
25
+ "火烈鸟", "小蓝鹭", "美国鹭", "麻鸦", "鹤", "秧鹤", "欧洲水鸡", "沼泽泥母鸡", "鸨", "红翻石鹬",
26
+ "红背鹬", "红脚鹬", "半蹼鹬", "蛎鹬", "鹈鹕", "国王企鹅", "信天翁", "灰鲸", "杀人鲸", "海牛",
27
+ "海狮", "吉娃娃", "日本狆犬", "马尔济斯犬", "狮子狗", "西施犬", "布莱尼姆猎犬", "巴比狗", "玩具犬", "罗得西亚长背猎狗",
28
+ "阿富汗猎犬", "巴吉度猎犬", "比格犬", "侦探犬", "蓝色快狗", "黑褐猎浣熊犬", "沃克猎犬", "英国猎狐犬", "美洲赤狗", "俄罗斯猎狼犬",
29
+ "爱尔兰猎狼犬", "意大利灰狗", "惠比特犬", "依比沙猎犬", "挪威猎犬", "奥达猎犬", "沙克犬", "苏格兰猎鹿犬", "威玛猎犬", "斯塔福德郡斗牛犬",
30
+ "美国斯塔福德郡梗", "贝德灵顿梗", "边境梗", "凯丽蓝梗", "爱尔兰梗", "诺福克梗", "诺维奇梗", "约克犬;约克夏梗犬", "刚毛猎狐梗", "莱克兰梗",
31
+ "锡利哈姆梗", "艾尔谷犬", "凯恩梗", "澳大利亚梗", "丹迪丁蒙梗", "波士顿梗", "迷你雪纳瑞犬", "巨型雪纳瑞犬", "标准雪纳瑞犬", "苏格兰梗犬",
32
+ "西藏梗", "丝毛梗", "爱尔兰软毛梗犬", "西高地白梗", "拉萨阿普索犬", "平毛寻回犬", "卷毛寻回犬", "金毛猎犬", "拉布拉多猎犬", "乞沙比克猎犬",
33
+ "德国短毛指示犬", "维兹拉犬", "英国塞特犬", "爱尔兰雪达犬", "戈登雪达犬", "布列塔尼犬猎犬", "黄毛", "英国史宾格犬", "威尔士史宾格犬", "可卡犬",
34
+ "萨塞克斯猎犬", "爱尔兰水猎犬", "哥威斯犬", "舒柏奇犬", "比利时牧羊犬", "马里努阿犬", "伯瑞犬", "凯尔皮犬", "匈牙利牧羊犬", "老英国牧羊犬",
35
+ "喜乐蒂牧羊犬", "牧羊犬", "边境牧羊犬", "法兰德斯牧牛狗", "罗特韦尔犬", "德国牧羊犬", "多伯曼犬", "鹿犬;迷你杜宾犬", "大瑞士山地犬", "伯恩山犬",
36
+ "阿策尔山犬", "恩特尔布赫山犬", "拳师狗", "斗牛獒", "藏獒", "法国斗牛犬", "大丹犬", "圣伯纳德狗", "爱斯基摩犬", "阿拉斯加雪橇犬",
37
+ "哈士奇", "达尔马提亚", "狮毛狗", "巴辛吉狗", "八哥犬", "莱昂贝格狗", "纽芬兰犬", "大白熊犬", "萨摩耶犬", "博美犬",
38
+ "松狮", "凯斯犬", "布鲁塞尔格林芬犬", "彭布洛克威尔士科基犬", "威尔士柯基犬", "玩具贵宾犬", "迷你贵宾犬", "标准贵宾犬", "墨西哥无毛犬", "灰狼",
39
+ "白狼", "红太狼", "狼", "澳洲野狗", "豺", "非洲猎犬", "鬣狗", "红狐狸", "沙狐", "北极狐狸",
40
+ "灰狐狸", "虎斑猫", "山猫", "波斯猫", "暹罗猫", "埃及猫", "美洲狮", "猞猁", "豹子", "雪豹",
41
+ "美洲虎", "狮子", "老虎", "猎豹", "棕熊", "美洲黑熊", "冰熊", "懒熊", "獴", "猫鼬",
42
+ "虎甲虫", "瓢虫", "土鳖虫", "天牛", "龟甲虫", "粪甲虫", "犀牛甲虫", "象甲", "苍蝇", "蜜蜂",
43
+ "蚂蚁", "蚱蜢", "蟋蟀", "竹节虫", "蟑螂", "螳螂", "蝉", "叶蝉", "草蜻蛉", "蜻蜓",
44
+ "豆娘", "优红蛱蝶", "小环蝴蝶", "君主蝴蝶", "菜粉蝶", "白蝴蝶", "灰蝶", "海星", "海胆", "海黄瓜;海参",
45
+ "野兔", "兔", "安哥拉兔", "仓鼠", "刺猬", "黑松鼠", "土拨鼠", "海狸", "豚鼠", "栗色马",
46
+ "斑马", "猪", "野猪", "疣猪", "河马", "牛", "水牛", "野牛", "公羊", "大角羊",
47
+ "山羊", "狷羚", "黑斑羚", "瞪羚", "阿拉伯单峰骆驼", "骆驼", "黄鼠狼", "水貂", "臭猫", "黑足鼬",
48
+ "水獭", "臭鼬", "獾", "犰狳", "树懒", "猩猩", "大猩猩", "黑猩猩", "长臂猿", "合趾猿长臂猿",
49
+ "长尾猴", "赤猴", "狒狒", "恒河猴", "白头叶猴", "疣猴", "长鼻猴", "狨(美洲产小型长尾猴)", "卷尾猴", "吼猴",
50
+ "伶猴", "蜘蛛猴", "松鼠猴", "马达加斯加环尾狐猴", "大狐猴", "印度大象", "非洲象", "小熊猫", "大熊猫", "杖鱼",
51
+ "鳗鱼", "银鲑", "三色刺蝶鱼", "海葵鱼", "鲟鱼", "雀鳝", "狮子鱼", "河豚", "算盘", "长袍",
52
+ "学位袍", "手风琴", "原声吉他", "航空母舰", "客机", "飞艇", "祭坛", "救护车", "水陆两用车", "模拟时钟",
53
+ "蜂房", "围裙", "垃圾桶", "攻击步枪", "背包", "面包店", "平衡木", "热气球", "圆珠笔", "创可贴",
54
+ "班卓琴", "栏杆", "杠铃", "理发师的椅子", "理发店", "牲口棚", "晴雨表", "圆筒", "园地小车", "棒球",
55
+ "篮球", "婴儿床", "巴松管", "游泳帽", "沐浴毛巾", "浴缸", "沙滩车", "灯塔", "烧杯", "熊皮高帽",
56
+ "啤酒瓶", "啤酒杯", "钟塔", "(小儿用的)围嘴", "串联自行车", "比基尼", "装订册", "双筒望远镜", "鸟舍", "船库",
57
+ "双人雪橇", "饰扣式领带", "阔边女帽", "书橱", "书店", "瓶盖", "弓箭", "蝴蝶结领结", "铜制牌位", "奶罩",
58
+ "防波堤", "铠甲", "扫帚", "桶", "扣环", "防弹背心", "动车", "肉铺", "出租车", "大锅",
59
+ "蜡烛", "大炮", "独木舟", "开瓶器", "开衫", "车镜", "旋转木马", "木匠的工具包", "纸箱", "车轮",
60
+ "取款机", "盒式录音带", "卡带播放器", "城堡", "双体船", "CD播放器", "大提琴", "移动电话", "铁链", "围栏",
61
+ "链甲", "电锯", "箱子", "梳妆台", "编钟", "中国橱柜", "圣诞袜", "教堂", "电影院", "切肉刀",
62
+ "悬崖屋", "斗篷", "木屐", "鸡尾酒调酒器", "咖啡杯", "咖啡壶", "螺旋结构(楼梯)", "组合锁", "电脑键盘", "糖果",
63
+ "集装箱船", "敞篷车", "瓶塞钻", "短号", "牛仔靴", "牛仔帽", "摇篮", "起重机", "头盔", "板条箱",
64
+ "小儿床", "砂锅", "槌球", "拐杖", "胸甲", "大坝", "书桌", "台式电脑", "有线电话", "尿布湿",
65
+ "数字时钟", "数字手表", "餐桌板", "抹布", "洗碗机", "盘式制动器", "码头", "狗拉雪橇", "圆顶", "门垫",
66
+ "钻井平台", "鼓", "鼓槌", "哑铃", "荷兰烤箱", "电风扇", "电吉他", "电力机车", "组合电视柜", "信封",
67
+ "浓缩咖啡机", "扑面粉", "女用长围巾", "文件", "消防船", "消防车", "火炉栏", "旗杆", "长笛", "折叠椅",
68
+ "橄榄球头盔", "叉车", "喷泉", "钢笔", "有四根帷柱的床", "运货车厢", "圆号", "煎锅", "裘皮大衣", "垃圾车",
69
+ "防毒面具", "汽油泵", "高脚杯", "卡丁车", "高尔夫球", "高尔夫球车", "狭长小船", "锣", "礼服", "钢琴",
70
+ "温室", "散热器格栅", "杂货店", "断头台", "小发夹", "头发喷雾", "半履带装甲车", "锤子", "大篮子", "手摇鼓风机",
71
+ "手提电脑", "手帕", "硬盘", "口琴", "竖琴", "收割机", "斧头", "手枪皮套", "家庭影院", "蜂窝",
72
+ "钩爪", "衬裙", "单杠", "马车", "沙漏", "iPod", "熨斗", "南瓜灯笼", "牛仔裤", "吉普车",
73
+ "T恤衫", "拼图", "人力车", "操纵杆", "和服", "护膝", "蝴蝶结", "大褂", "长柄勺", "灯罩",
74
+ "笔记本电脑", "割草机", "镜头盖", "开信刀;拆信刀", "图书馆", "救生艇", "点火器", "豪华轿车", "远洋班轮", "唇膏",
75
+ "平底便鞋", "洗剂", "扬声器", "放大镜", "锯木厂", "磁罗盘", "邮袋", "信箱", "女游泳衣", "有肩带浴衣",
76
+ "窨井盖", "沙球(一种打击乐器)", "马林巴木琴", "面膜", "火柴", "花柱", "迷宫", "量杯", "药箱", "巨石",
77
+ "麦克风", "微波炉", "军装", "奶桶", "迷你巴士", "迷你裙", "面包车;小型货车", "导弹", "连指手套", "搅拌钵",
78
+ "活动房屋(由汽车拖拉的)", "福特T型车", "调制解调器;光猫", "修道院", "显示器", "电瓶车", "砂浆", "学士", "清真寺", "蚊帐",
79
+ "摩托车", "山地自行车", "登山帐", "鼠标", "捕鼠器", "搬家货车", "动物的口套", "金属钉子", "颈托", "项链",
80
+ "乳头(瓶)", "平板电脑", "方尖碑", "双簧管", "小鹅笛;球形笛(管身椭圆形)", "里程表", "滤油器", "风琴", "示波器", "罩裙",
81
+ "牛车", "氧气面罩", "包装", "船桨", "明轮", "挂锁", "画笔", "睡衣", "宫殿", "排箫",
82
+ "纸巾", "降落伞", "双杠", "公园长椅", "停车收费表", "客车", "露台", "付费电话", "基座", "铅笔盒",
83
+ "卷笔刀", "香水(瓶)", "培养皿", "复印机", "拨弦片", "尖顶头盔", "用尖板条连成的尖桩篱栅", "皮卡", "桥墩", "存钱罐",
84
+ "药瓶", "枕头", "乒乓球", "风车", "海盗船", "水罐", "木工刨", "天文馆", "塑料袋", "板架",
85
+ "犁型铲雪机", "手压皮碗泵", "宝丽来相机", "电线杆", "警车", "雨披", "台球桌", "充气饮料瓶", "花盆", "陶工旋盘",
86
+ "电钻", "祈祷垫", "打印机", "监狱", "炮弹", "投影仪", "冰球", "沙包", "小钱袋;手袋", "羽管笔",
87
+ "被子", "赛车", "球拍", "散热器", "收音机", "射电望远镜", "雨桶", "休闲车", "卷轴", "反射式照相机",
88
+ "冰箱", "遥控器", "餐厅", "左轮手枪", "步枪", "摇椅", "电转烤肉架", "橡皮", "橄榄球", "直尺",
89
+ "跑步鞋", "保险柜", "安全别针", "盐瓶(调味用)", "凉鞋", "纱笼", "萨克斯管", "剑鞘", "秤", "校车",
90
+ "帆船", "记分牌", "屏幕", "螺丝", "螺丝刀", "安全带", "缝纫机", "盾牌", "皮鞋店", "障子",
91
+ "购物篮", "购物车", "铁锹", "浴帽", "浴帘", "滑雪板", "滑雪面罩", "睡袋", "滑尺", "滑动门",
92
+ "角子老虎机", "潜水通气管", "摩托雪橇;雪地机动车", "扫雪机", "皂液器", "足球", "袜子", "碟式太阳能", "宽边帽", "汤碗",
93
+ "空格键", "空间加热器", "航天飞机", "锅铲;做饭的铲子", "快艇", "蜘蛛网", "纺锤;手纺用的绕线杆", "跑车", "聚光灯", "舞台",
94
+ "蒸汽机车", "钢拱桥", "钢滚筒", "听诊器", "女用披肩", "石头墙", "秒表", "火炉", "过滤器", "有轨电车",
95
+ "担架", "沙发床", "佛塔", "潜艇", "套装", "日晷", "太阳镜", "太阳镜", "防晒霜", "悬索桥",
96
+ "拖把", "运动衫", "游泳裤", "秋千", "开关", "注射器;吸管", "台灯", "坦克", "录音机", "茶壶",
97
+ "泰迪", "电视", "网球;打网球的球", "茅草", "幕布", "顶针", "打谷机;脱粒机", "宝座", "瓦屋顶", "烤面包机",
98
+ "烟草店", "马桶", "火炬", "图腾柱", "拖车;牵引车", "玩具店", "拖拉机", "半挂汽车", "托盘", "风衣",
99
+ "三轮车", "三体船", "三脚架", "凯旋门", "无轨电车", "长号", "浴盆", "旋转式栅门", "打字机键盘", "伞",
100
+ "独轮车", "直立式钢琴", "吸尘器", "花瓶;装饰瓶", "拱顶", "天鹅绒", "自动售货机", "法衣;祭衣;祭服", "高架桥", "小提琴",
101
+ "排球", "松饼机", "挂钟", "钱包;钱夹", "衣柜衣橱", "军用飞机", "洗脸盆", "洗衣机", "水瓶", "水壶",
102
+ "水塔", "威士忌壶", "哨子", "假发", "纱窗", "百叶窗", "温莎领带", "葡萄酒瓶", "飞机翅膀", "炒菜锅",
103
+ "木勺子;木头勺子", "毛织品", "原木栅栏", "沉船", "双桅船", "蒙古包", "网站;网页", "漫画", "纵横字谜", "路标",
104
+ "交通信号灯", "防尘罩", "菜单", "盘子", "墨西哥鳄梨酱;墨西哥牛油果酱", "清炖肉汤", "火锅", "乳脂蛋糕;英国甜点", "冰淇淋", "冰棍;雪糕",
105
+ "法式面包", "百吉饼", "椒盐脆饼", "芝士汉堡", "热狗", "土豆泥", "结球甘蓝", "西兰花;绿菜花", "菜花;花椰菜", "西葫芦",
106
+ "金丝瓜;意面南瓜;面条瓜", "绿色小南瓜;青南瓜", "南瓜", "黄瓜", "洋蓟;球蓟", "甜椒", "刺棘蓟", "蘑菇", "绿苹果", "草莓",
107
+ "橘子", "柠檬", "无花果", "菠萝", "香蕉", "菠萝蜜", "番荔枝", "石榴", "干草", "培根蛋酱意大利面",
108
+ "巧克力酱", "生面;面团", "瑞士肉包", "披萨", "馅饼", "卷饼", "红葡萄酒", "意式浓缩咖啡", "杯子", "蛋酒",
109
+ "高山", "泡泡", "悬崖", "珊瑚礁", "间歇泉;间断喷发的温泉", "湖边", "岬角;深入海中的狭长高地", "沙洲", "沙滩", "峡谷",
110
+ "火山", "棒球运动员", "新郎", "潜水员", "油菜", "雏菊", "黄色杓兰", "玉米", "橡子", "玫瑰果",
111
+ "七叶树果实", "珊瑚菌", "木耳", "鹿花菌", "臭角菇", "地星", "多叶奇果菌", "牛肝菌", "玉米棒子", "卫生纸"]
112
+
113
+ openai_imagenet_template = [
114
+ lambda c: f'{c}的照片。',
115
+ lambda c: f'质量差的{c}的照片。',
116
+ lambda c: f'许多{c}的照片。',
117
+ lambda c: f'{c}的雕塑。',
118
+ lambda c: f'难以看到{c}的照片。',
119
+ lambda c: f'{c}的低分辨率照片。',
120
+ lambda c: f'{c}的渲染。',
121
+ lambda c: f'涂鸦{c}。',
122
+ lambda c: f'{c}的糟糕照片。',
123
+ lambda c: f'{c}的裁剪照片。',
124
+ lambda c: f'{c}的纹身。',
125
+ lambda c: f'{c}的刺绣照片。',
126
+ lambda c: f'很难看到{c}的照片。',
127
+ lambda c: f'{c}的明亮照片。',
128
+ lambda c: f'一张干净的{c}的照片。',
129
+ lambda c: f'一张包含{c}的照片。',
130
+ lambda c: f'{c}的深色照片。',
131
+ lambda c: f'{c}的手绘画。',
132
+ lambda c: f'���的{c}的照片。',
133
+ lambda c: f'不自然的{c}的照片。',
134
+ lambda c: f'一张酷的{c}的照片。',
135
+ lambda c: f'{c}的特写照片。',
136
+ lambda c: f'{c}的黑白照片。',
137
+ lambda c: f'一幅{c}的画。',
138
+ lambda c: f'一幅{c}的绘画。',
139
+ lambda c: f'一张{c}的像素照片。',
140
+ lambda c: f'{c}的雕像。',
141
+ lambda c: f'一张{c}的明亮照片。',
142
+ lambda c: f'{c}的裁剪照片。',
143
+ lambda c: f'人造的{c}的照片。',
144
+ lambda c: f'一张关于{c}的照片。',
145
+ lambda c: f'损坏的{c}的jpeg照片。',
146
+ lambda c: f'{c}的模糊照片。',
147
+ lambda c: f'{c}的相片。',
148
+ lambda c: f'一张{c}的好照片。',
149
+ lambda c: f'{c}的渲染照。',
150
+ lambda c: f'视频游戏中的{c}。',
151
+ lambda c: f'一张{c}的照片。',
152
+ lambda c: f'{c}的涂鸦。',
153
+ lambda c: f'{c}的近距离照片。',
154
+ lambda c: f'{c}的折纸。',
155
+ lambda c: f'{c}在视频游戏中。',
156
+ lambda c: f'{c}的草图。',
157
+ lambda c: f'{c}的涂鸦照。',
158
+ lambda c: f'{c}的折纸形状。',
159
+ lambda c: f'低分辨率的{c}的照片。',
160
+ lambda c: f'玩具{c}。',
161
+ lambda c: f'{c}的副本。',
162
+ lambda c: f'{c}的干净的照片。',
163
+ lambda c: f'一张大{c}的照片。',
164
+ lambda c: f'{c}的重现。',
165
+ lambda c: f'一张漂亮的{c}的照片。',
166
+ lambda c: f'一张奇怪的{c}的照片。',
167
+ lambda c: f'模糊的{c}的照片。',
168
+ lambda c: f'卡通{c}。',
169
+ lambda c: f'{c}的艺术作品。',
170
+ lambda c: f'{c}的素描。',
171
+ lambda c: f'刺绣{c}。',
172
+ lambda c: f'{c}的像素照。',
173
+ lambda c: f'{c}的拍照。',
174
+ lambda c: f'{c}的损坏的照片。',
175
+ lambda c: f'高质量的{c}的照片。',
176
+ lambda c: f'毛绒玩具{c}。',
177
+ lambda c: f'漂亮的{c}的照片。',
178
+ lambda c: f'小{c}的照片。',
179
+ lambda c: f'照片是奇怪的{c}。',
180
+ lambda c: f'漫画{c}。',
181
+ lambda c: f'{c}的艺术照。',
182
+ lambda c: f'{c}的图形。',
183
+ lambda c: f'大{c}的照片。',
184
+ lambda c: f'黑白的{c}的照片。',
185
+ lambda c: f'{c}毛绒玩具。',
186
+ lambda c: f'一张{c}的深色照片。',
187
+ lambda c: f'{c}的摄影图。',
188
+ lambda c: f'{c}的涂鸦照。',
189
+ lambda c: f'玩具形状的{c}。',
190
+ lambda c: f'拍了{c}的照片。',
191
+ lambda c: f'酷酷的{c}的照片。',
192
+ lambda c: f'照片里的小{c}。',
193
+ lambda c: f'{c}的刺青。',
194
+ ]
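To illustrate how these prompt templates are consumed (classname picked arbitrarily from the list above):

    >>> openai_imagenet_template[0]('金鱼')
    '金鱼的照片。'
    >>> openai_imagenet_template[1]('金鱼')
    '质量差的金鱼的照片。'

The zero-shot evaluation script further below tokenizes every template for a class, encodes them, and averages the normalized text embeddings into a single classifier weight per class.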
Model/CLIP/cn_clip/eval/make_topk_predictions.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs kNN search on the inferred image and text features (single-GPU) and outputs a text-to-image prediction file for evaluation.
4
+ '''
5
+
6
+ import argparse
7
+ import numpy
8
+ from tqdm import tqdm
9
+ import json
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--image-feats',
18
+ type=str,
19
+ required=True,
20
+ help="Specify the path of image features."
21
+ )
22
+ parser.add_argument(
23
+ '--text-feats',
24
+ type=str,
25
+ required=True,
26
+ help="Specify the path of text features."
27
+ )
28
+ parser.add_argument(
29
+ '--top-k',
30
+ type=int,
31
+ default=10,
32
+ help="Specify the k value of top-k predictions."
33
+ )
34
+ parser.add_argument(
35
+ '--eval-batch-size',
36
+ type=int,
37
+ default=32768,
38
+ help="Specify the image-side batch size when computing the inner products, default to 8192"
39
+ )
40
+ parser.add_argument(
41
+ '--output',
42
+ type=str,
43
+ required=True,
44
+ help="Specify the output jsonl prediction filepath."
45
+ )
46
+ return parser.parse_args()
47
+
48
+ if __name__ == "__main__":
49
+ args = parse_args()
50
+
51
+ # Log params.
52
+ print("Params:")
53
+ for name in sorted(vars(args)):
54
+ val = getattr(args, name)
55
+ print(f" {name}: {val}")
56
+
57
+ print("Begin to load image features...")
58
+ image_ids = []
59
+ image_feats = []
60
+ with open(args.image_feats, "r") as fin:
61
+ for line in tqdm(fin):
62
+ obj = json.loads(line.strip())
63
+ image_ids.append(obj['image_id'])
64
+ image_feats.append(obj['feature'])
65
+ image_feats_array = np.array(image_feats, dtype=np.float32)
66
+ print("Finished loading image features.")
67
+
68
+ print("Begin to compute top-{} predictions for texts...".format(args.top_k))
69
+ with open(args.output, "w") as fout:
70
+ with open(args.text_feats, "r") as fin:
71
+ for line in tqdm(fin):
72
+ obj = json.loads(line.strip())
73
+ text_id = obj['text_id']
74
+ text_feat = obj['feature']
75
+ score_tuples = []
76
+ text_feat_tensor = torch.tensor([text_feat], dtype=torch.float).cuda() # [1, feature_dim]
77
+ idx = 0
78
+ while idx < len(image_ids):
79
+ img_feats_tensor = torch.from_numpy(image_feats_array[idx : min(idx + args.eval_batch_size, len(image_ids))]).cuda() # [batch_size, feature_dim]
80
+ batch_scores = text_feat_tensor @ img_feats_tensor.t() # [1, batch_size]
81
+ for image_id, score in zip(image_ids[idx : min(idx + args.eval_batch_size, len(image_ids))], batch_scores.squeeze(0).tolist()):
82
+ score_tuples.append((image_id, score))
83
+ idx += args.eval_batch_size
84
+ top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
85
+ fout.write("{}\n".format(json.dumps({"text_id": text_id, "image_ids": [entry[0] for entry in top_k_predictions]})))
86
+
87
+ print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
88
+ print("Done!")
Model/CLIP/cn_clip/eval/make_topk_predictions_tr.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs kNN search on the inferred image and text features (single-GPU) and outputs an image-to-text retrieval prediction file for evaluation.
4
+ '''
5
+
6
+ import argparse
7
+ import numpy
8
+ from tqdm import tqdm
9
+ import json
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--image-feats',
18
+ type=str,
19
+ required=True,
20
+ help="Specify the path of image features."
21
+ )
22
+ parser.add_argument(
23
+ '--text-feats',
24
+ type=str,
25
+ required=True,
26
+ help="Specify the path of text features."
27
+ )
28
+ parser.add_argument(
29
+ '--top-k',
30
+ type=int,
31
+ default=10,
32
+ help="Specify the k value of top-k predictions."
33
+ )
34
+ parser.add_argument(
35
+ '--eval-batch-size',
36
+ type=int,
37
+ default=32768,
38
+ help="Specify the image-side batch size when computing the inner products, default to 8192"
39
+ )
40
+ parser.add_argument(
41
+ '--output',
42
+ type=str,
43
+ required=True,
44
+ help="Specify the output jsonl prediction filepath."
45
+ )
46
+ return parser.parse_args()
47
+
48
+ if __name__ == "__main__":
49
+ args = parse_args()
50
+
51
+ # Log params.
52
+ print("Params:")
53
+ for name in sorted(vars(args)):
54
+ val = getattr(args, name)
55
+ print(f" {name}: {val}")
56
+
57
+ print("Begin to load text features...")
58
+ text_ids = []
59
+ text_feats = []
60
+ with open(args.text_feats, "r") as fin:
61
+ for line in tqdm(fin):
62
+ obj = json.loads(line.strip())
63
+ text_ids.append(obj['text_id'])
64
+ text_feats.append(obj['feature'])
65
+ text_feats_array = np.array(text_feats, dtype=np.float32)
66
+ print("Finished loading text features.")
67
+
68
+ print("Begin to compute top-{} predictions for images...".format(args.top_k))
69
+ with open(args.output, "w") as fout:
70
+ with open(args.image_feats, "r") as fin:
71
+ for line in tqdm(fin):
72
+ obj = json.loads(line.strip())
73
+ image_id = obj['image_id']
74
+ image_feat = obj['feature']
75
+ score_tuples = []
76
+ image_feat_tensor = torch.tensor([image_feat], dtype=torch.float).cuda() # [1, feature_dim]
77
+ idx = 0
78
+ while idx < len(text_ids):
79
+ text_feats_tensor = torch.from_numpy(text_feats_array[idx : min(idx + args.eval_batch_size, len(text_ids))]).cuda() # [batch_size, feature_dim]
80
+ batch_scores = image_feat_tensor @ text_feats_tensor.t() # [1, batch_size]
81
+ for text_id, score in zip(text_ids[idx : min(idx + args.eval_batch_size, len(text_ids))], batch_scores.squeeze(0).tolist()):
82
+ score_tuples.append((text_id, score))
83
+ idx += args.eval_batch_size
84
+ top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
85
+ fout.write("{}\n".format(json.dumps({"image_id": image_id, "text_ids": [entry[0] for entry in top_k_predictions]})))
86
+
87
+ print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
88
+ print("Done!")
Model/CLIP/cn_clip/eval/transform_ir_annotation_to_tr.py ADDED
@@ -0,0 +1,36 @@
1
+ # -*- coding: utf-8 -*-
2
+ from tqdm import tqdm
3
+ import argparse
4
+ import json
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument(
9
+ '--input',
10
+ type=str,
11
+ required=True,
12
+ help="Input path of text-to-image Jsonl annotation file."
13
+ )
14
+ return parser.parse_args()
15
+
16
+ if __name__ == "__main__":
17
+ args = parse_args()
18
+
19
+ t2i_record = dict()
20
+
21
+ with open(args.input, "r") as fin:
22
+ for line in tqdm(fin):
23
+ obj = json.loads(line.strip())
24
+ text_id = obj['text_id']
25
+ image_ids = obj['image_ids']
26
+ for image_id in image_ids:
27
+ if image_id not in t2i_record:
28
+ t2i_record[image_id] = []
29
+ t2i_record[image_id].append(text_id)
30
+
31
+ with open(args.input.replace(".jsonl", "") + ".tr.jsonl", "w") as fout:
32
+ for image_id, text_ids in t2i_record.items():
33
+ out_obj = {"image_id": image_id, "text_ids": text_ids}
34
+ fout.write("{}\n".format(json.dumps(out_obj)))
35
+
36
+ print("Done!")
Model/CLIP/cn_clip/eval/zeroshot_evaluation.py ADDED
@@ -0,0 +1,189 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs zero-shot evaluation on ImageNet-1K (single-GPU).
4
+ '''
5
+
6
+ import os
7
+ import argparse
8
+ from pathlib import Path
9
+ import json
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+
14
+ from cn_clip.clip.model import convert_weights, CLIP
15
+ from cn_clip.clip import tokenize
16
+ from cn_clip.training.main import convert_models_to_fp32
17
+ from cn_clip.clip.utils import image_transform
18
+ from cn_clip.eval.data import get_imagenet_dataset, _preprocess_text
19
+ from cn_clip.eval.imagenet_zeroshot_templates import imagenet_classnames, openai_imagenet_template
20
+
21
+
22
+ def parse_args():
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "--vision-model",
26
+ choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
27
+ default="ViT-B-16",
28
+ help="Name of the vision backbone to use.",
29
+ )
30
+ parser.add_argument(
31
+ "--text-model",
32
+ choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
33
+ default="RoBERTa-wwm-ext-base-chinese",
34
+ help="Name of the text backbone to use.",
35
+ )
36
+ parser.add_argument(
37
+ "--precision",
38
+ choices=["amp", "fp16", "fp32"],
39
+ default="amp",
40
+ help="Floating point precition."
41
+ )
42
+ parser.add_argument(
43
+ "--imagenet-val",
44
+ type=str,
45
+ required=True,
46
+ help="Path to imagenet val set for conducting zero shot evaluation.",
47
+ )
48
+ parser.add_argument(
49
+ "--img-batch-size", type=int, default=64, help="Image batch size."
50
+ )
51
+ parser.add_argument(
52
+ "--context-length",
53
+ type=int,
54
+ default=32,
55
+ help="The maximum length of input text (include [CLS] & [SEP] tokens)."
56
+ )
57
+ parser.add_argument(
58
+ "--resume",
59
+ default=None,
60
+ type=str,
61
+ help="path to latest checkpoint (default: none)",
62
+ )
63
+ parser.add_argument(
64
+ "--num-workers", type=int, default=4, help="Number of workers for ImageNet dataset."
65
+ )
66
+ args = parser.parse_args()
67
+
68
+ return args
69
+
70
+
71
+ def zero_shot_classifier(model, classnames, templates, args):
72
+ with torch.no_grad():
73
+ zeroshot_weights = []
74
+ for classname in tqdm(classnames):
75
+ texts = [_preprocess_text(template(classname)) for template in templates] #format with class
76
+ texts = tokenize(texts, context_length=args.context_length).to(args.gpu) #tokenize
77
+ class_embeddings = model(None, texts)
78
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
79
+ class_embedding = class_embeddings.mean(dim=0)
80
+ class_embedding /= class_embedding.norm()
81
+ zeroshot_weights.append(class_embedding)
82
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
83
+ return zeroshot_weights
84
+
85
+
86
+ def accuracy(output, target, topk=(1,)):
87
+ pred = output.topk(max(topk), 1, True, True)[1].t()
88
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
89
+ return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
90
+
91
+
92
+ def run(model, classifier, dataloader, args):
93
+ with torch.no_grad():
94
+ top1, top5, n = 0., 0., 0.
95
+ for images, target in tqdm(dataloader):
96
+ images = images.to(args.gpu)
97
+ target = target.to(args.gpu)
98
+
99
+ # predict
100
+ image_features = model(images, None)
101
+ image_features /= image_features.norm(dim=-1, keepdim=True)
102
+ logits = 100. * image_features @ classifier
103
+
104
+ # measure accuracy
105
+ acc1, acc5 = accuracy(logits, target, topk=(1, 5))
106
+ top1 += acc1
107
+ top5 += acc5
108
+ n += images.size(0)
109
+
110
+ top1 = (top1 / n)
111
+ top5 = (top5 / n)
112
+ return top1, top5
113
+
114
+
115
+ if __name__ == "__main__":
116
+ args = parse_args()
117
+
118
+ # Log params.
119
+ print("Params:")
120
+ for name in sorted(vars(args)):
121
+ val = getattr(args, name)
122
+ print(f" {name}: {val}")
123
+
124
+ args.gpu = 0
125
+ torch.cuda.set_device(args.gpu)
126
+
127
+ # Initialize the model.
128
+ vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
129
+ print('Loading vision model config from', vision_model_config_file)
130
+ assert os.path.exists(vision_model_config_file)
131
+
132
+ text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
133
+ print('Loading text model config from', text_model_config_file)
134
+ assert os.path.exists(text_model_config_file)
135
+
136
+ with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
137
+ model_info = json.load(fv)
138
+ if isinstance(model_info['vision_layers'], str):
139
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
140
+ for k, v in json.load(ft).items():
141
+ model_info[k] = v
142
+
143
+ model = CLIP(**model_info)
144
+ convert_weights(model)
145
+
146
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
147
+ if args.precision == "amp" or args.precision == "fp32":
148
+ convert_models_to_fp32(model)
149
+ model.cuda(args.gpu)
150
+ if args.precision == "fp16":
151
+ convert_weights(model)
152
+
153
+ # Get imagenet eval data.
154
+ print("Preparing imagenet val dataset.")
155
+ data = {}
156
+ data["imagenet-val"] = get_imagenet_dataset(args, image_transform(model_info['image_resolution']), "val")
157
+
158
+ # Resume from a checkpoint.
159
+ print("Begin to load model checkpoint from {}.".format(args.resume))
160
+ assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
161
+ # Map model to be loaded to specified single gpu.
162
+ loc = "cuda:{}".format(args.gpu)
163
+ checkpoint = torch.load(args.resume, map_location='cpu')
164
+ start_epoch = checkpoint["epoch"]
165
+ sd = checkpoint["state_dict"]
166
+ if next(iter(sd.items()))[0].startswith('module'):
167
+ sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
168
+ model.load_state_dict(sd)
169
+ print(
170
+ f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
171
+ )
172
+
173
+ # Compute ensembled class embeddings
174
+ print('Building zero-shot classifier')
175
+
176
+ model.eval()
177
+
178
+ classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
179
+
180
+ # Make inference and evaluation
181
+ print('Using classifier')
182
+ results = {}
183
+ top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
184
+ results['imagenet-zeroshot-val-top1'] = top1
185
+ results['imagenet-zeroshot-val-top5'] = top5
186
+
187
+ print('Result:')
188
+ print(", ".join(["{}: {}".format(k, v) for k, v in results.items()]))
189
+ print('Finished.')
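A hedged invocation sketch using only the flags defined above (both paths are placeholders; run from Model/CLIP so that the cn_clip package resolves):

    python cn_clip/eval/zeroshot_evaluation.py \
        --imagenet-val /path/to/imagenet/val \
        --resume path/to/clip_checkpoint.pt \
        --vision-model ViT-B-16 --text-model RoBERTa-wwm-ext-base-chinese \
        --img-batch-size 64 --context-length 32
    # prints imagenet-zeroshot-val-top1 and imagenet-zeroshot-val-top5 as fractions in [0, 1]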
Model/CLIP/cn_clip/preprocess/__init__.py ADDED
File without changes