update all file v1
This view is limited to 50 files because it contains too many changes.
- Model/AttDes/__init__.py +4 -0
- Model/AttDes/__pycache__/__init__.cpython-38.pyc +0 -0
- Model/AttDes/dataset/data_loader.py +170 -0
- Model/AttDes/models/AttDes.py +17 -0
- Model/AttDes/models/Chinese_tokenizer.pth +3 -0
- Model/AttDes/models/__init__.py +1 -0
- Model/AttDes/models/__pycache__/__init__.cpython-38.pyc +0 -0
- Model/AttDes/models/language_model/bert.py +50 -0
- Model/AttDes/models/prefixLM.py +108 -0
- Model/AttDes/models/resblock.py +353 -0
- Model/AttDes/models/tokenizer.py +92 -0
- Model/AttDes/models/transformer.py +291 -0
- Model/AttDes/models/visual_model/Chinese_tokenizer.pth +3 -0
- Model/AttDes/models/visual_model/backbone.py +121 -0
- Model/AttDes/models/visual_model/position_encoding.py +89 -0
- Model/AttDes/validate_local.py +399 -0
- Model/AttDes/validate_local_gennerate.py +332 -0
- Model/CLIP/cn_clip/__init__.py +0 -0
- Model/CLIP/cn_clip/__pycache__/__init__.cpython-38.pyc +0 -0
- Model/CLIP/cn_clip/clip/__init__.py +5 -0
- Model/CLIP/cn_clip/clip/__pycache__/__init__.cpython-38.pyc +0 -0
- Model/CLIP/cn_clip/clip/__pycache__/bert_tokenizer.cpython-38.pyc +0 -0
- Model/CLIP/cn_clip/clip/__pycache__/utils.cpython-38.pyc +0 -0
- Model/CLIP/cn_clip/clip/bert_tokenizer.py +436 -0
- Model/CLIP/cn_clip/clip/configuration_bert.py +84 -0
- Model/CLIP/cn_clip/clip/model.py +504 -0
- Model/CLIP/cn_clip/clip/model_configs/RBT3-chinese.json +13 -0
- Model/CLIP/cn_clip/clip/model_configs/RN50.json +7 -0
- Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json +13 -0
- Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-large-chinese.json +13 -0
- Model/CLIP/cn_clip/clip/model_configs/ViT-B-16.json +7 -0
- Model/CLIP/cn_clip/clip/model_configs/ViT-B-32.json +7 -0
- Model/CLIP/cn_clip/clip/model_configs/ViT-H-14.json +8 -0
- Model/CLIP/cn_clip/clip/model_configs/ViT-L-14-336.json +7 -0
- Model/CLIP/cn_clip/clip/model_configs/ViT-L-14.json +7 -0
- Model/CLIP/cn_clip/clip/model_configs/for_learn.py +16 -0
- Model/CLIP/cn_clip/clip/modeling_bert.py +460 -0
- Model/CLIP/cn_clip/clip/utils.py +196 -0
- Model/CLIP/cn_clip/clip/vocab.txt +0 -0
- Model/CLIP/cn_clip/eval/__init__.py +0 -0
- Model/CLIP/cn_clip/eval/data.py +167 -0
- Model/CLIP/cn_clip/eval/evaluation.py +157 -0
- Model/CLIP/cn_clip/eval/evaluation_tr.py +157 -0
- Model/CLIP/cn_clip/eval/extract_features.py +205 -0
- Model/CLIP/cn_clip/eval/imagenet_zeroshot_templates.py +194 -0
- Model/CLIP/cn_clip/eval/make_topk_predictions.py +88 -0
- Model/CLIP/cn_clip/eval/make_topk_predictions_tr.py +88 -0
- Model/CLIP/cn_clip/eval/transform_ir_annotation_to_tr.py +36 -0
- Model/CLIP/cn_clip/eval/zeroshot_evaluation.py +189 -0
- Model/CLIP/cn_clip/preprocess/__init__.py +0 -0
Model/AttDes/__init__.py
ADDED
@@ -0,0 +1,4 @@
import Model.AttDes.models
import Model.AttDes.dataset
Model/AttDes/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (208 Bytes).
Model/AttDes/dataset/data_loader.py
ADDED
@@ -0,0 +1,170 @@
import os
import re
import time

import cv2
import sys
import json

import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy as np
import pandas as pd
import os.path as osp
import scipy.io as sio
import torch.utils.data as data
from PIL import Image
import matplotlib.image as mping
import torchvision.transforms as transforms

from PIL import Image
from pytorch_pretrained_bert.tokenization import BertTokenizer


def get_data_from_csv(path):
    data_csv = pd.read_csv(path, encoding='utf-8')
    # print(data_csv)
    pic_id_list = data_csv['pic_id'].values
    seg_id_list = data_csv['seg_id'].values
    object_list = data_csv['object'].values
    segment_list = data_csv['segment'].values
    adj_list = data_csv['adj'].values
    des_list = data_csv['des'].values

    return pic_id_list, seg_id_list, object_list, segment_list, adj_list, des_list

class AttDesDataset(data.Dataset):

    def __init__(self, data_root, dataset_name, img_root, dataset_split='train', transform=None,
                 bert_model='bert-base-chinese',
                 des_len=256, obj_len=8, tgt_len=32
                 ):
        self.images = []
        self.data_root = data_root
        self.dataset_name = dataset_name
        self.transform = transform
        self.img_root = img_root
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.des_len = des_len
        self.obj_len = obj_len
        self.tgt_len = tgt_len
        assert self.transform is not None
        self.pic_id_list, self.seg_id_list, self.object_list, self.segment_list, self.adj_list, self.des_list = \
            get_data_from_csv(self.data_root)
        self.data_csv = pd.read_csv(data_root, encoding='utf-8')

    def get_data_from_csv_by_id(self, id, dict=None):
        pic_id_list = self.data_csv['pic_id'].values
        # {id: des_}
        des_list = self.data_csv['des'].values
        start_time = time.time()
        for i in range(len(pic_id_list)):
            if str(pic_id_list[i]) == str(id):
                # print("find: str(pic_id_list[i]) == str(id)", time.time() - start_time)
                return des_list[i]
        return ""

    def get_img_from_id(self, img_id):
        img_filename = self.img_root
        img_filename = img_filename + '/' + str(img_id) + '.jpg'
        img = Image.open(img_filename)
        if self.transform:
            img = self.transform(img)
        return img

    def encode_text_bert(self, text):
        tokens = []
        tokens.append("[CLS]")
        token_obj = self.tokenizer.tokenize(text)
        for token in token_obj:
            tokens.append(token)
        tokens.append("[SEP]")
        tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        return tokens

    def get_all_from_id(self, img_id, obj_given):
        img_id = str(img_id)
        if img_id[0] == '#':
            des = ""
        else:
            des = self.get_data_from_csv_by_id(img_id)
        img = self.get_img_from_id(img_id)
        des = self.encode_text_bert(des)
        obj_given = self.encode_text_bert(obj_given)
        while(len(des) < self.des_len):
            des.append(100)
        while(len(obj_given) < self.obj_len):
            obj_given.append(0)
        assert len(des) == self.des_len
        return img, torch.from_numpy(np.array(des)), torch.from_numpy(np.array(obj_given))

    def __getitem__(self, idx):
        img_id = self.pic_id_list[idx]
        img = self.get_img_from_id(img_id)

        # des = self.des_list[idx].split('[,,;]')
        des = re.split(',|;', str(self.des_list[idx]))
        masked_des = "" # chinese
        for i in range(len(des)):
            if i != int(self.seg_id_list[idx]):
                masked_des = masked_des + des[i] + ' '

        obj = self.object_list[idx] # chinese
        segment = self.segment_list[idx] # chinese
        masked_des = self.encode_text_bert(masked_des)
        obj = self.encode_text_bert(obj)
        segment = self.encode_text_bert(segment)
        while(len(masked_des) < self.des_len):
            masked_des.append(100)
        while(len(obj) < self.obj_len):
            obj.append(0)
        while(len(segment) < self.tgt_len):
            segment.append(0)

        assert len(masked_des) == self.des_len
        assert len(obj) == self.obj_len
        assert len(segment) == self.tgt_len
        return img, np.array(masked_des), np.array(obj), np.array(segment), img_id

    def __len__(self):
        return len(self.pic_id_list)


if __name__ == '__main__':
    data_root = r'E:\data\Download\fur\dataset\data_for_test1.csv'
    split_root = ''
    dataset_name = 'Furniture'
    #
    # get_data_from_csv(data_root)
    # img_id = 550709
    # img = get_img_from_id(img_id)
    # plt.imshow(img)
    # plt.show()
    normalize = transforms.Normalize(mean=[0, 0, 0],
                                     std=[1, 1, 1])
    dataset = AttDesDataset(data_root, dataset_name, transform=transforms.Compose([
        transforms.Resize((448,448)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    img, masked_des, obj, segment = dataset.__getitem__(100)

    img_show = np.zeros((len(img[0]), len(img[0][0]), 3))
    img_show[:, :, 0] = img[0]
    img_show[:, :, 1] = img[1]
    img_show[:, :, 2] = img[2]
    plt.imshow(img_show)
    plt.show()
    print(masked_des, len(masked_des))
    print(obj, len(obj))
    print(segment, len(segment))
    print(dataset.__len__())

    # sentence_for_test = "原木地板的厚实与白色纱幔的轻飘营造朴素和浪漫的氛围,而一张编织餐椅灵动轻巧"
    # tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    # tokenizer.tokenize(sentence_for_test)
    # print(tokenizer.tokenize(sentence_for_test))
    # print(tokenizer.convert_tokens_to_ids(sentence_for_test))
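Note: a minimal sketch of how AttDesDataset might be wrapped in a standard torch DataLoader for training; the CSV and image paths below are placeholders, and the transform mirrors the __main__ block above.

# Hypothetical usage sketch (paths are placeholders, not part of the commit).
import torch.utils.data as data
import torchvision.transforms as transforms
from Model.AttDes.dataset.data_loader import AttDesDataset

transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
])

dataset = AttDesDataset(
    data_root='path/to/data.csv',   # placeholder CSV with pic_id/seg_id/object/segment/adj/des columns
    dataset_name='Furniture',
    img_root='path/to/images',      # directory containing <pic_id>.jpg files
    transform=transform,
)

# Each batch yields (img, masked_des, obj, segment, img_id), as produced by __getitem__.
loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)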
Model/AttDes/models/AttDes.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_pretrained_bert.modeling import BertModel


class AttDes(nn.Module):
    def __init__(self, args):
        super(AttDes, self).__init__()
        hidden_dim = args.AD_hidden_dim
Model/AttDes/models/Chinese_tokenizer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2403030d0e018aedffec4a62d69b124c350a5b1ef03035395dbcb3593deca8dd
size 142959
Model/AttDes/models/__init__.py
ADDED
@@ -0,0 +1 @@
Model/AttDes/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (145 Bytes).
Model/AttDes/models/language_model/bert.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F

from torch import nn
from typing import Dict, List

from pytorch_pretrained_bert.modeling import BertModel
from utils.misc import NestedTensor


class BERT(nn.Module):
    def __init__(self, name: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num):
        super().__init__()
        if name == 'bert-base-uncased':
            self.num_channels = 768
        else:
            self.num_channels = 1024
        self.enc_num = enc_num

        self.bert = BertModel.from_pretrained(name)

        if not train_bert:
            for parameter in self.bert.parameters():
                parameter.requires_grad_(False)

    def forward(self, tensor_list: NestedTensor):

        if self.enc_num > 0:
            all_encoder_layers, _ = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
            # use the output of the X-th transformer encoder layers
            xs = all_encoder_layers[self.enc_num - 1]
        else:
            xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)

        mask = tensor_list.mask.to(torch.bool)
        mask = ~mask
        out = NestedTensor(xs, mask)

        return out
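Note: a minimal sketch of how the BERT wrapper above might be driven. It assumes utils.misc.NestedTensor (not part of this commit) is a simple container built as NestedTensor(tensors, mask) and exposing .tensors and .mask, as forward() expects; the token ids below are dummies.

# Hypothetical usage sketch; NestedTensor construction is an assumption.
import torch
from utils.misc import NestedTensor
from Model.AttDes.models.language_model.bert import BERT

model = BERT('bert-base-uncased', train_bert=False, hidden_dim=256, max_len=40, enc_num=12)

token_ids = torch.randint(0, 30522, (2, 40))          # dummy ids for a batch of 2 sequences
attention_mask = torch.ones(2, 40, dtype=torch.long)  # 1 = real token, 0 = padding

out = model(NestedTensor(token_ids, attention_mask))
# out.tensors: (2, 40, 768) hidden states from the 12th encoder layer
# out.mask:    inverted padding mask, True where positions should be ignored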
Model/AttDes/models/prefixLM.py
ADDED
@@ -0,0 +1,108 @@
'''
author: yulong-XJTU
'''
import torch
from torch import nn
import torch.nn.functional as F
import copy
from AttDes.models.transformer import Transformer, subsequent_mask, ModelOne, Model005, Model006
from axial_positional_embedding import AxialPositionalEmbedding
from AttDes.models.resblock import BottleneckBlock
from random import randint
from einops import rearrange

def clone(module,N):
    '''copy the given module N times'''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class PrefixLM(nn.Module):
    def __init__(
        self, des_len, obj_len, tgt_len,
        d_model=512,
        input_resolution=224,
        patch_size=16,
        num_text_tokens=10000,
        txt_seq_len=256,
        prefix_txt_len=25,
        target_txt_len=52,
        max_trunc_txt_len=15,
        heads=8,
        enc_depth=12,
        dec_depth=12,
        d_ff=1024,
        dropout=0.,
    ):
        super(PrefixLM,self).__init__()
        assert input_resolution % patch_size==0 and max_trunc_txt_len<=prefix_txt_len and max_trunc_txt_len<txt_seq_len
        self.ResNet = nn.Sequential(*[nn.Conv2d(in_channels=3, out_channels=64, kernel_size=patch_size, stride=patch_size, bias=True),
                                      BottleneckBlock(in_channels=64,out_channels=256,bottleneck_channels=64,),
                                      BottleneckBlock(in_channels=256,out_channels=d_model,bottleneck_channels=128)])
        self.des_len = des_len
        self.obj_len = obj_len
        self.tgt_len = tgt_len
        self.txt_embed = nn.Embedding(num_text_tokens, d_model, padding_idx=0)
        self.txt_pos_embed = nn.Embedding(self.des_len,d_model)
        image_fmap_size = input_resolution // patch_size  # 448 // 16
        self.img_tokens_len=image_fmap_size ** 2
        # self.img_pos_embed=nn.Embedding(self.img_tokens_len,d_model)
        self.img_pos_embed = AxialPositionalEmbedding(d_model, axial_shape=(image_fmap_size, image_fmap_size))
        self.txt_seq_len = txt_seq_len
        self.target_txt_len = target_txt_len
        self.prefix_txt_len = prefix_txt_len

        self.max_trunc_txt_len=max_trunc_txt_len
        self.num_text_tokens = num_text_tokens
        self.dim_embed=d_model
        self.input_resolution=input_resolution
        self.patch_size=patch_size
        # self.temperature = nn.Parameter(torch.tensor(1.))  # not mentioned in the paper
        self.transformer=Transformer(d_model,heads,enc_depth,dec_depth,d_ff,dropout=dropout)
        self.ModelOne = Model005(d_model,heads,enc_depth,dec_depth,d_ff,dropout=dropout)
        self.to_logits = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, self.num_text_tokens)
        )

    def forward(self, img, des, obj, tgt, return_loss=False):
        device = des.device
        n = des.shape[0]
        img_emed = self.ResNet(img)
        img_emed = rearrange(img_emed,'b c h w -> b (h w) c')
        img_emed = img_emed + self.img_pos_embed(img_emed)
        del img
        # add <CLS>; if you change the tokenizer, don't forget to change the token ID. Another [SEP] token is added at the end (in tokenizer.py, please check).
        tgt = F.pad(tgt, (1, 0), value=4)
        labels = tgt[:,1:]
        tgt = tgt[:,:-1]
        # print('des:', torch.min(des), torch.max(des))
        des_embed = self.txt_embed(des)
        des_embed = des_embed + self.txt_pos_embed(torch.arange(self.des_len, device=device))

        obj_embed = self.txt_embed(obj)
        obj_embed = obj_embed + self.txt_pos_embed(torch.arange(self.obj_len, device=device))

        tgt_embed = self.txt_embed(tgt)
        tgt_embed = tgt_embed + self.txt_pos_embed(torch.arange(self.tgt_len, device=device))
        tgt_mask = subsequent_mask(self.tgt_len).to(device)

        # baseline
        # prefix = torch.cat((img_emed, des_embed, obj_embed), dim=1)
        # tgt_mask = subsequent_mask(self.tgt_len).to(device)
        # out = self.transformer(prefix, tgt_embed, tgt_mask=tgt_mask)

        # ModelOne: call the Model005 instance constructed in __init__
        out = self.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                            tgt_embeded=tgt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                            tgt_mask=tgt_mask)

        logits = self.to_logits(out)
        return logits, labels
        # if not return_loss:
        #     return logits
        # # temp = self.temperature.exp()
        # logits = rearrange(logits, 'b n c -> b c n')
        # # logits=logits*temp  # with temperature parameter
        # loss=F.cross_entropy(logits,labels,ignore_index=0)
        # return loss
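Note: the training loss itself is left commented out at the end of forward(); the sketch below shows that same cross-entropy-over-vocabulary scheme applied to the (logits, labels) pair the model returns. The shapes are assumptions taken from the code above (batch, tgt_len, num_text_tokens).

# Hypothetical sketch of the commented-out loss computation, applied outside the model.
import torch
import torch.nn.functional as F
from einops import rearrange

def prefix_lm_loss(logits, labels):
    # logits: (batch, tgt_len, num_text_tokens), labels: (batch, tgt_len)
    logits = rearrange(logits, 'b n c -> b c n')            # cross_entropy expects (batch, classes, positions)
    return F.cross_entropy(logits, labels, ignore_index=0)  # token id 0 is the [PAD] id

# e.g. with dummy tensors:
logits = torch.randn(2, 32, 10000)
labels = torch.randint(1, 10000, (2, 32))
loss = prefix_lm_loss(logits, labels)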
Model/AttDes/models/resblock.py
ADDED
@@ -0,0 +1,353 @@
# Code adapted from https://github.com/facebookresearch/detectron2

from torch import nn
import torch
import torch.nn.functional as F

def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.
    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
        # torch.Tensor]`.
        nn.init.constant_(module.bias, 0)

def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": torch.nn.BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            #"SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            #"naiveSyncBN": NaiveSyncBatchNorm,
            # expose stats_mode N as an option to caller, required for zero-len inputs
            #"naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
        }[norm]
    return norm(out_channels)

class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and convert all BatchNorm layers to FrozenBatchNorm

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self

class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        #*,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                c2_msra_fill(layer)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."

        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out
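Note: a stand-alone sketch of the patchify-then-bottleneck stack that PrefixLM builds from these blocks (a strided Conv2d with stride equal to the patch size, followed by two BottleneckBlocks); the 224x224 input size and d_model of 512 are assumptions chosen for illustration.

# Hypothetical shape check for the BottleneckBlock stack used as PrefixLM.ResNet.
import torch
from torch import nn
from Model.AttDes.models.resblock import BottleneckBlock

stem = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=16, stride=16, bias=True),  # 16x16 patchify
    BottleneckBlock(in_channels=64, out_channels=256, bottleneck_channels=64),
    BottleneckBlock(in_channels=256, out_channels=512, bottleneck_channels=128),
)
stem.eval()  # avoid BatchNorm batch-statistics issues on a tiny dummy batch

with torch.no_grad():
    feats = stem(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 512, 14, 14]); flattened to 196 image tokens of width d_model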
Model/AttDes/models/tokenizer.py
ADDED
@@ -0,0 +1,92 @@
# take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
# to give users a quick easy start to training DALL-E without doing BPE

import torch


# from transformers import BertTokenizer

import html
import os
from functools import lru_cache
from pathlib import Path
import ftfy
import regex as re

# OpenAI simple tokenizer

@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")

@lru_cache()
def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


# chinese tokenizer
class ChineseTokenizer:
    def __init__(self):
        tokenizer = torch.load('./models/Chinese_tokenizer.pth')  # BertTokenizer.from_pretrained('bert-base-chinese')
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size+2

    def decode(self, tokens):
        if torch.is_tensor(tokens):
            tokens = tokens.tolist()

        tokens = [token for token in tokens if token not in (0,)]
        return self.tokenizer.decode(tokens)

    def encode(self, text,train=False):
        t=torch.tensor(self.tokenizer.encode(text, add_special_tokens=False))
        if train:
            return torch.cat([t,torch.tensor([5])],dim=-1)
        else:
            return t
    #special token: [CLS]==4,[SEP]==5, [PAD]==0,<bos>=7

    def tokenize(self, texts, context_length = 77, truncate_text = False,train=True):
        if isinstance(texts, str):
            texts = [texts]

        all_tokens = [self.encode(text,train=train) for text in texts]

        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate_text:
                    tokens = tokens[:context_length]
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
Model/AttDes/models/transformer.py
ADDED
@@ -0,0 +1,291 @@
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
# helpers
def clone(module,N):
    '''copy the given module N times'''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def subsequent_mask(size):
    attn_shape=(1,size,size)
    subsequent_mask=np.triu(np.ones(attn_shape),k=1).astype(bool)
    return torch.from_numpy(subsequent_mask)==False


def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clone(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        '''print('q:',query)
        print('k:',key)
        print('v:',value)'''
        nbatchs = query.size(0)
        query, key, value = [l(x).view(nbatchs, -1, self.h, self.d_k).transpose(1, 2) \
                             for l, x in zip(self.linears, (query, key, value))]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatchs, -1, self.h * self.d_k)
        return self.linears[-1](x)

class Feedforward(nn.Module):
    def __init__(self,d_model,d_ff,dropout=0.1):
        super(Feedforward,self).__init__()
        self.w_1=nn.Linear(d_model,d_ff)
        self.w_2=nn.Linear(d_ff,d_model)
        self.dropout=nn.Dropout(dropout)

    def forward(self,x):
        return self.w_2(self.dropout(F.relu((self.w_1(x)))))


class LayerNorm(nn.Module):
    def __init__(self,features,eps=1e-6):
        super(LayerNorm,self).__init__()
        self.a_2=nn.Parameter(torch.ones(features))
        self.b_2=nn.Parameter(torch.zeros(features))
        self.eps=eps

    def forward(self,x):
        mean=x.mean(-1,keepdim=True)
        std=x.std(-1,keepdim=True)
        return self.a_2*(x-mean)/(std+self.eps)+self.b_2


class Generator(nn.Module):
    def __init__(self,d_model,vocab):
        super(Generator,self).__init__()
        self.proj=nn.Linear(d_model,vocab)

    def forward(self,x):
        return F.log_softmax(self.proj(x),dim=-1)


# encoderLayer clone numbers times of enc_depth.
# repeat the encoderLayer enc_depth times
class Encoder(nn.Module):
    def __init__(self, layer, N):
        '''N encoder layers '''
        super(Encoder,self).__init__()
        self.layers = clone(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self,x,mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class SublayerConnection(nn.Module):
    '''LayerNorm +subLayer+dropout+residual connection'''
    def __init__(self,size,dropout):
        super(SublayerConnection,self).__init__()
        self.norm=LayerNorm(size)
        self.dropout=nn.Dropout(dropout)

    def forward(self,x,sublayer):
        return x+self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    def __init__(self,size,self_attn,feed_forward,dropout):
        '''size is the embedding dimension'''
        super(EncoderLayer,self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clone(SublayerConnection(size,dropout),2)
        self.size = size

    def forward(self,x,mask=None):
        x = self.sublayer[0](x, lambda x: self.self_attn(x,x,x,mask))
        return self.sublayer[1](x, self.feed_forward)


class Decoder(nn.Module):
    def __init__(self,layer,N):
        super(Decoder,self).__init__()
        self.layers = clone(layer,N)
        self.norm = LayerNorm(layer.size)

    def forward(self,x, memory,src_mask=None,tgt_mask=None):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


class DecoderLayer(nn.Module):
    def __init__(self,size,self_attn,src_attn,feed_forward,dropout):
        super(DecoderLayer,self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clone(SublayerConnection(size,dropout),3)

    def forward(self,x,memory,src_mask=None,tgt_mask=None):
        m = memory
        x = self.sublayer[0](x,lambda x: self.self_attn(x,x,x,tgt_mask))
        x = self.sublayer[1](x,lambda x: self.src_attn(x,m,m,src_mask))
        return self.sublayer[2](x,self.feed_forward)


class CrossAttLayer(nn.Module):
    def __init__(self,d_model,self_attn,feed_forward,dropout=0.1):
        super(CrossAttLayer, self).__init__()
        self.size = d_model
        self.self_attn = self_attn
        # self.self_attn_0 = copy.deepcopy(self_attn)
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout)
        self.sublayer = clone(SublayerConnection(d_model, dropout), 2)
        # self.sublayer = clone(SublayerConnection(d_model,dropout),3)  # could be changed to three sublayers, with the first one being self_attn

    def forward(self,q,k,v,src_mask=None):
        # k = self.sublayer[0](k, lambda k: self.self_attn_0(k,k,k))
        # q = self.sublayer[0](q, lambda q: self.self_attn_0(q,q,q))
        # x = self.sublayer[1](q, lambda q: self.self_attn(q,k,k,src_mask))
        # x = self.sublayer[2](x, self.feed_forward)
        x = self.sublayer[0](q, lambda q: self.self_attn(q,k,k,src_mask))
        x = self.sublayer[1](x, self.feed_forward)
        return x


class CrossAtt(nn.Module):
    def __init__(self, crossAttlayer, N=1):
        super(CrossAtt, self).__init__()
        self.layers = clone(crossAttlayer,N)
        self.norm = LayerNorm(crossAttlayer.size)

    def forward(self, q, k, v, src_mask=None):
        for crossAttnLayer in self.layers:
            q = crossAttnLayer(q, k, v, src_mask)
        return self.norm(q)

class Transformer(nn.Module):
    def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
        super(Transformer,self).__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(heads,d_model)
        ff = Feedforward(d_model,d_ff,dropout)
        self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
        self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
        #self.register_buffer('src_mask', src_mask, persistent=False)
        #self.register_buffer('tgt_mask', tgt_mask, persistent=False)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self,src_embeded,tgt_embeded,src_mask=None,tgt_mask=None):
        return self.decode(self.encode(src_embeded,src_mask),tgt_embeded,src_mask,tgt_mask)

    def encode(self,src_embeded,src_mask=None):
        return self.encoder(src_embeded,src_mask)

    def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
        return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)


class ModelOne(nn.Module):
    def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
        super(ModelOne,self).__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(heads,d_model)
        ff = Feedforward(d_model,d_ff,dropout)
        self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
        self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
        self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
        #self.register_buffer('src_mask', src_mask, persistent=False)
        #self.register_buffer('tgt_mask', tgt_mask, persistent=False)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
        # x = self.CrossAtt(q, img_embed, img_embed)
        # x2 = self.CrossAtt(q, des_embed, des_embed)
        des_embed_self = self.CrossAtt(des_embed, des_embed, des_embed)
        x3 = self.CrossAtt(img_embed, des_embed_self, des_embed_self)
        # src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
        src_embeded = torch.cat((x3, obj_embed), dim=1)
        x = self.encode(src_embeded,src_mask)
        x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
        return x

    def encode(self,src_embeded,src_mask=None):
        return self.encoder(src_embeded,src_mask)

    def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
        return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)

class Model005(nn.Module):
    def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
        super(Model005,self).__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(heads,d_model)
        ff = Feedforward(d_model,d_ff,dropout)
        self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
        self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
        self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
        x = self.CrossAtt(q, img_embed, img_embed)
        src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
        x = self.encode(src_embeded,src_mask)
        x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
        return x

    def encode(self,src_embeded,src_mask=None):
        return self.encoder(src_embeded,src_mask)

    def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
        return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)

class Model006(nn.Module):
    def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
        super(Model006,self).__init__()
        c = copy.deepcopy
        attn = MultiHeadedAttention(heads,d_model)
        ff = Feedforward(d_model,d_ff,dropout)
        self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
        self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
        self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
        x = self.CrossAtt(img_embed, img_embed, img_embed)
        x = self.CrossAtt(obj_embed, x, x)
        src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
        x = self.encode(src_embeded,src_mask)
        x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
        return x

    def encode(self,src_embeded,src_mask=None):
        return self.encoder(src_embeded,src_mask)

    def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
        return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
Model/AttDes/models/visual_model/Chinese_tokenizer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2403030d0e018aedffec4a62d69b124c350a5b1ef03035395dbcb3593deca8dd
size 142959
Model/AttDes/models/visual_model/backbone.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from utils.misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, name:str, backbone: nn.Module, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 return_interm_layers: bool,
                 dilation: bool):

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=False, norm_layer=FrozenBatchNorm2d)
        # pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        assert name in ('resnet50', 'resnet101')
        num_channels = 2048
        super().__init__(name, backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    # train_backbone = args.lr_detr > 0
    return_interm_layers = False
    backbone = Backbone(args.backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
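Note: FrozenBatchNorm2d can be exercised on its own; the sketch below just shows that with its default buffers it applies a fixed per-channel affine transform that is effectively the identity. Driving build_backbone end to end would additionally require the utils.misc helpers, which are not part of this commit.

# Hypothetical sketch: with default buffers the transform is (x - 0) / sqrt(1 + eps) * 1 + 0.
import torch
from Model.AttDes.models.visual_model.backbone import FrozenBatchNorm2d

fbn = FrozenBatchNorm2d(8)
x = torch.randn(2, 8, 16, 16)
y = fbn(x)
print(torch.allclose(y, x, atol=1e-4))  # True up to the 1e-5 eps folded into the scale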
Model/AttDes/models/visual_model/position_encoding.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the visual model.
"""
import math
import torch
from torch import nn

from utils.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
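A small, hedged sketch of the sine encoding on its own; the all-False mask (every pixel valid) and the feature shape are assumptions for illustration:

import torch
from utils.misc import NestedTensor  # same class the module imports

feat = torch.rand(2, 256, 28, 28)                # e.g. a backbone feature map
mask = torch.zeros(2, 28, 28, dtype=torch.bool)  # False = valid position
pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(NestedTensor(feat, mask))
print(pos.shape)  # torch.Size([2, 256, 28, 28]); channels = 2 * num_pos_feats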
Model/AttDes/validate_local.py
ADDED
@@ -0,0 +1,399 @@
import argparse
import datetime
import json
import random
import time
import math
import os
import pandas as pd
import numpy as np
from pathlib import Path
import torch
from nltk.translate import bleu_score
import sys
sys.path.append(r"E:\data\streamlit\Model\AttDes")
sys.path.append(r"E:\data\streamlit\Model\CLIP")
from AttDes import dataset
from AttDes.dataset import data_loader
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.transforms as transforms
import AttDes.models as models
from AttDes.models import prefixLM, tokenizer

import nltk
import jieba
# from engine import train_one_epoch, validate
#
# import utils.misc as utils
# from models import __init__
# from dataset import build_dataset
# from engine import train_one_epoch, validate_txt

from einops import rearrange
from pytorch_pretrained_bert.tokenization import BertTokenizer

def get_args_parser():
    parser = argparse.ArgumentParser('Set parser', add_help=False)
    parser.add_argument('--device', default='cuda')
    # parser.add_argument('--gpu_id', default='0', type=str)

    # Dataset parameters
    parser.add_argument('--data_root', type=str, default='/hy-nas/zhanghe/data/fur/txt/data_for_test2.csv')
    parser.add_argument('--dataset_name', type=str, default='Furniture')
    parser.add_argument('--img_root', type=str, default='/hy-nas/zhanghe/data/fur/processed_img')
    parser.add_argument('--output_dir', default='./outputs/validate', help='path where to save, empty for no saving')
    parser.add_argument('--seed', default=2022, type=int)
    parser.add_argument('--resume', default='', help='resume for checkpoint')
    parser.add_argument('--bert_model', default='bert-base-chinese', type=str)
    parser.add_argument('--des_len', default=256, type=int)
    parser.add_argument('--obj_len', default=8, type=int)
    parser.add_argument('--tgt_len', default=35, type=int)

    # Train parameters
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--optimizer', default='adamw', type=str)
    parser.add_argument('--lr_scheduler', default='step', type=str)
    parser.add_argument('--lr_drop', default=5, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=1, type=int)

    # Model parameters
    parser.add_argument('--AD_hidden_dim', default=256, type=int)
    parser.add_argument('--d_model', default=512, type=int)
    # visual_model parameters
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")

    return parser


def main(args):
    device = torch.device(args.device)

    # seed = args.seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])
    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    # use the `dataset` module imported above (the bare name `AttDes` is not bound in this file)
    dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                                    des_len=args.des_len,
                                                    obj_len=args.obj_len,
                                                    tgt_len=args.tgt_len,
                                                    img_root=args.img_root,
                                                    transform=the_transforms)

    dataloader_val = DataLoader(dataset_all,
                                batch_size=args.batch_size,
                                shuffle=False)
    print("data loaded...")

    Tokenizer = tokenizer.ChineseTokenizer()
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load('./outputs/005/checkpoint0019.pth'))

    output_dir = Path(args.output_dir)
    with (output_dir / "log.txt").open("a") as f:
        f.write(str(args) + "\n")

    print("start validate...")
    start_time = time.time()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2000)
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    for epoch in range(args.start_epoch, args.epochs):
        validate_txt(args, model, dataloader_val, device, batch_size=args.batch_size)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Validate time {}'.format(total_time_str))


def load_AttDes_Model(model_path, device):
    parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    args = parser.parse_args()
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])
    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    dataset_all = data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                            des_len=args.des_len,
                                            obj_len=args.obj_len,
                                            tgt_len=args.tgt_len,
                                            img_root=args.img_root,
                                            transform=the_transforms)
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    time_1 = time.time()
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load(model_path))
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    time_2 = time.time()
    print('Load model takes {}s'.format(time_2 - time_1))
    return model, dataset_all, tokenizer


def validate(img1_id, img2_id, obj, model_path):
    parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    args = parser.parse_args()
    device = torch.device(args.device)
    #
    # seed = args.seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])

    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                                    des_len=args.des_len,
                                                    obj_len=args.obj_len,
                                                    tgt_len=args.tgt_len,
                                                    img_root=args.img_root,
                                                    transform=the_transforms)
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    time_1 = time.time()
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load(model_path))
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    time_2 = time.time()
    print('Load model takes {}s'.format(time_2 - time_1))
    out_list = []

    label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img1_id, obj, device, tokenizer)
    out_list.append([label_txt, output1, output2, output3])
    label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img2_id, obj, device, tokenizer)
    out_list.append([label_txt, output1, output2, output3])
    return out_list

def get_data_from_csv_by_id(path, id):
    data_csv = pd.read_csv(path, encoding='utf-8')
    # print(data_csv)
    pic_id_list = data_csv['pic_id'].values
    des_list = data_csv['des'].values

    for i in range(len(pic_id_list)):
        if str(pic_id_list[i]) == str(id):
            return des_list[i]

    return ""

def validate_one_img(model, dataset_all, img_ids, obj_given, device, tokenizer):
    batch_size = len(img_ids)
    start_time = time.time()
    model.eval()
    imgs = []
    dess = []
    objs = []
    for i in range(len(img_ids)):
        img, des, obj = dataset_all.get_all_from_id(img_ids[i], obj_given[i])
        # print("get img from id time:", time.time() - start_time)  # 3s
        imgs.append(img)
        dess.append(des)
        objs.append(obj)
    img_data = torch.stack(imgs).to(device)
    des_data = torch.stack(dess).to(device)
    obj_data = torch.stack(objs).to(device)
    # print("get batch time:", time.time() - start_time)  # 3s
    img_emed = model.ResNet(img_data)
    img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
    img_emed += model.img_pos_embed(img_emed)

    des_embed = model.txt_embed(des_data)
    des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
    obj_embed = model.txt_embed(obj_data)
    obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))

    tgt_txt = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + 101
    tgt_txt_embed = model.txt_embed(tgt_txt)
    tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)

    # M_005
    out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                         tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                         tgt_mask=None)
    logits = model.to_logits(out)[:, -1]
    _, index = logits.topk(3, dim=-1)
    # value: tensor([[7.3227, 7.2289, 6.4169],
    #                [9.6868, 7.0598, 6.3911]], device='cuda:0', grad_fn=<TopkBackward0>)
    # index: tensor([[4677, 2199, 2647],
    #                [4510, 3763, 2145]], device='cuda:0')
    sample_1st = index[:, 0]
    sample_2nd = index[:, 1]
    sample_3rd = index[:, 2]
    tgt_txt0 = tgt_txt
    output_list = []
    # print("get 1,2,3 sample time:", time.time() - start_time)  # 0.01s
    for sample in [sample_1st, sample_2nd, sample_3rd]:
        tgt_txt = tgt_txt0
        cur_len = 1
        while (cur_len < model.tgt_len and sample.max() != 102):  # 102 is the id of [SEP]
            tgt_txt = torch.cat((tgt_txt, sample.unsqueeze(1)), dim=-1)
            tgt_txt_embed = model.txt_embed(tgt_txt)
            cur_len += 1
            tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
            # out = model.transformer(prefix, tgt_txt_embed)
            out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                                 tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                                 tgt_mask=None)
            logits = model.to_logits(out)[:, -1]
            sample = torch.argmax(logits, dim=-1)
        # print("one batch sentence token time:", time.time() - start_time)  # 0.6s
        output_1 = []
        for i in range(batch_size):
            output_txt = []
            for token in tgt_txt[i].tolist():
                if token > 103:
                    output_txt.append(token)
            output_txt = tokenizer.convert_ids_to_tokens(output_txt)
            output_txt = ''.join(output_txt)
            output_1.append(output_txt[1:])
        output_list.append(output_1)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Validate time {}'.format(total_time_str))
    # print(output_list)
    return output_list


def generate_texts(img_id, obj, model_path):
    parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    args = parser.parse_args()
    device = torch.device(args.device)
    # seed = args.seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])
    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                                    des_len=args.des_len,
                                                    obj_len=args.obj_len,
                                                    tgt_len=args.tgt_len,
                                                    img_root=args.img_root,
                                                    transform=the_transforms)
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    time_1 = time.time()
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load(model_path))
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    time_2 = time.time()
    print('Load model takes {}s'.format(time_2 - time_1))

    print("start generate_texts")
    start_time = time.time()
    end1_time = time.time()
    model.eval()
    img_data, des, obj_data, target, img_id, obj_given = dataset_all.get_all_from_id(img_id, obj)

    img_data = img_data.unsqueeze(0).to(device)
    des = des.unsqueeze(0).to(device)
    obj_given = obj_given.unsqueeze(0).to(device)
    label = target.unsqueeze(0).to(device)

    img_emed = model.ResNet(img_data)

    img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
    img_emed += model.img_pos_embed(img_emed)

    des_embed = model.txt_embed(des)
    des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
    obj_embed = model.txt_embed(obj_given)
    obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))
    tgt_txt = torch.zeros(1, 1, dtype=torch.long, device=device) + 101
    tgt_txt_embed = model.txt_embed(tgt_txt)
    tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)

    # M_005
    out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                         tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                         tgt_mask=None)


if __name__ == '__main__':
    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    # args = parser.parse_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # if args.output_dir:
    #     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # main(args)
    model_name = '005'
    model_path = r'E:\data\Download\models\attribute_desciption\outputs' + '/' + model_name + '/' + 'checkpoint0019.pth'
    obj = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
           "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]

    out = generate_texts('550695', obj, model_path)
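A hedged sketch of how the helpers above would be called from another script; the checkpoint path, image id, and object word are placeholders, and the import path assumes the sys.path setup at the top of this file:

import torch
from validate_local import load_AttDes_Model, validate_one_img

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, dataset_all, bert_tok = load_AttDes_Model('outputs/005/checkpoint0019.pth', device)
outputs = validate_one_img(model, dataset_all, ['550695'], ['沙发'], device, bert_tok)
# outputs holds three lists (decodes seeded by the top-1/2/3 first tokens), one description per image id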
Model/AttDes/validate_local_gennerate.py
ADDED
@@ -0,0 +1,332 @@
import argparse
import datetime
import json
import random
import time
import math
import os

import numpy as np
from pathlib import Path

import torch
from nltk.translate import bleu_score

import dataset.data_loader
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.transforms as transforms
from models import prefixLM, tokenizer
import nltk
import jieba
# from engine import train_one_epoch, validate
#
# import utils.misc as utils
from models import __init__
# from dataset import build_dataset
# from engine import train_one_epoch, validate_txt

from einops import rearrange
from pytorch_pretrained_bert.tokenization import BertTokenizer

def get_args_parser():
    parser = argparse.ArgumentParser('Set parser', add_help=False)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--gpu_id', default='0', type=str)

    # Dataset parameters
    parser.add_argument('--data_root', type=str, default=r'E:\data\Download\fur\dataset\data_for_test2.csv')
    parser.add_argument('--dataset_name', type=str, default='Furniture')
    parser.add_argument('--img_root', type=str, default=r'E:\data\pictures')
    parser.add_argument('--output_dir', default='./outputs/validate', help='path where to save, empty for no saving')
    parser.add_argument('--seed', default=2022, type=int)
    parser.add_argument('--resume', default='', help='resume for checkpoint')
    parser.add_argument('--bert_model', default='bert-base-chinese', type=str)
    parser.add_argument('--des_len', default=256, type=int)
    parser.add_argument('--obj_len', default=8, type=int)
    parser.add_argument('--tgt_len', default=35, type=int)

    # Train parameters
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--optimizer', default='adamw', type=str)
    parser.add_argument('--lr_scheduler', default='step', type=str)
    parser.add_argument('--lr_drop', default=5, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=1, type=int)

    # Model parameters
    parser.add_argument('--AD_hidden_dim', default=256, type=int)
    parser.add_argument('--d_model', default=512, type=int)
    # visual_model parameters
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")

    return parser


def main(args):
    device = torch.device(args.device)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])
    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                                    des_len=args.des_len,
                                                    obj_len=args.obj_len,
                                                    tgt_len=args.tgt_len,
                                                    img_root=args.img_root,
                                                    transform=the_transforms)

    dataloader_val = DataLoader(dataset_all,
                                batch_size=args.batch_size,
                                shuffle=False)
    print("data loaded...")

    Tokenizer = tokenizer.ChineseTokenizer()
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load('./outputs/005/checkpoint0019.pth'))

    output_dir = Path(args.output_dir)
    with (output_dir / "log.txt").open("a") as f:
        f.write(str(args) + "\n")

    print("start validate...")
    start_time = time.time()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2000)
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    for epoch in range(args.start_epoch, args.epochs):
        validate_txt(args, model, dataloader_val, device, batch_size=args.batch_size)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Validate time {}'.format(total_time_str))


def validate(img1_id, img2_id, obj, model_path):
    parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    args = parser.parse_args()
    device = torch.device(args.device)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    normalize = transforms.Normalize(mean=[0.5024, 0.4993, 0.4992],
                                     std=[0.1673, 0.1695, 0.1705])
    the_transforms = transforms.Compose([transforms.Resize((448, 448)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         normalize,
                                         ])
    dataset_all = dataset.data_loader.AttDesDataset(args.data_root, args.dataset_name,
                                                    des_len=args.des_len,
                                                    obj_len=args.obj_len,
                                                    tgt_len=args.tgt_len,
                                                    img_root=args.img_root,
                                                    transform=the_transforms)
    PrefixLM_configure = dict(d_model=args.d_model, des_len=args.des_len, obj_len=args.obj_len, tgt_len=args.tgt_len,
                              input_resolution=448,
                              patch_size=16,
                              num_text_tokens=20000,
                              txt_seq_len=10000,
                              heads=4,
                              enc_depth=8,
                              dec_depth=8,
                              d_ff=1024,
                              dropout=0.1)
    time_1 = time.time()
    model = prefixLM.PrefixLM(**PrefixLM_configure).to(device)
    model.load_state_dict(torch.load(model_path))
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    time_2 = time.time()
    print('Load model takes {}s'.format(time_2 - time_1))
    out_list = []

    label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img1_id, obj, device, tokenizer)
    out_list.append([label_txt, output1, output2, output3])
    label_txt, output1, output2, output3 = validate_one_img(model, dataset_all, img2_id, obj, device, tokenizer)
    out_list.append([label_txt, output1, output2, output3])
    return out_list

def validate_one_img(model, dataset_all, img_id, obj, device, tokenizer):
    # print("start validate...")
    start_time = time.time()
    end1_time = time.time()
    model.eval()
    print(obj)
    img_data, des, obj_data, target, img_id, obj_given = dataset_all.get_all_from_id(img_id, obj)
    print(obj_given)
    img_data = img_data.unsqueeze(0).to(device)
    des = des.unsqueeze(0).to(device)
    obj_given = obj_given.unsqueeze(0).to(device)
    label = target.unsqueeze(0).to(device)

    img_emed = model.ResNet(img_data)

    img_emed = rearrange(img_emed, 'b c h w -> b (h w) c')
    img_emed += model.img_pos_embed(img_emed)

    des_embed = model.txt_embed(des)
    des_embed += model.txt_pos_embed(torch.arange(model.des_len, device=device))
    obj_embed = model.txt_embed(obj_given)
    obj_embed = obj_embed + model.txt_pos_embed(torch.arange(model.obj_len, device=device))
    tgt_txt = torch.zeros(1, 1, dtype=torch.long, device=device) + 101
    tgt_txt_embed = model.txt_embed(tgt_txt)
    tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.tgt_len)

    # M_005
    out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                         tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                         tgt_mask=None)
    logits = model.to_logits(out)[:, -1]
    sample = torch.argmax(logits, dim=-1)
    value, index = logits.topk(3, dim=-1)
    sample = index[0][0].unsqueeze(0)
    sample_2nd = index[0][1].unsqueeze(0)
    sample_3rd = index[0][2].unsqueeze(0)
    tgt_txt_2nd = tgt_txt
    tgt_txt_3rd = tgt_txt

    cur_len = 1
    while (cur_len < model.tgt_len and sample != 102):  # 102 is the id of [SEP]
        tgt_txt = torch.cat((tgt_txt, sample.unsqueeze(1)), dim=-1)
        tgt_txt_embed = model.txt_embed(tgt_txt)
        cur_len += 1
        tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
        # out = model.transformer(prefix, tgt_txt_embed)
        out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                             tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                             tgt_mask=None)
        logits = model.to_logits(out)[:, -1]
        sample = torch.argmax(logits, dim=-1)
    label_txt = []
    output_txt = []
    obj_txt = []
    for token in des[0].tolist():
        if token > 103:
            label_txt.append(token)
    for token in tgt_txt[0].tolist():
        if token > 103:
            output_txt.append(token)
    # for token in obj_data[0].tolist():
    #     if token > 103:
    #         obj_txt.append(token)
    label_txt = tokenizer.convert_ids_to_tokens(label_txt)
    label_txt = ''.join(label_txt)

    # obj_txt = tokenizer.convert_ids_to_tokens(obj_txt)
    output_txt = tokenizer.convert_ids_to_tokens(output_txt)
    output1 = ''.join(output_txt)

    # 2nd
    cur_len = 1
    while (cur_len < model.tgt_len and sample_2nd != 102):  # 102 is the id of [SEP]
        tgt_txt_2nd = torch.cat((tgt_txt_2nd, sample_2nd.unsqueeze(1)), dim=-1)
        tgt_txt_embed = model.txt_embed(tgt_txt_2nd)
        cur_len += 1
        tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
        # out = model.transformer(prefix, tgt_txt_embed)
        out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                             tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                             tgt_mask=None)
        logits = model.to_logits(out)[:, -1]
        # logits = logits[:, :-26]
        # print(logits)
        sample_2nd = torch.argmax(logits, dim=-1)

    output_txt = []
    for token in tgt_txt_2nd[0].tolist():
        if token > 103:
            output_txt.append(token)
    output_txt = tokenizer.convert_ids_to_tokens(output_txt)
    output2 = ''.join(output_txt)
    # 3rd
    cur_len = 1
    while (cur_len < model.tgt_len and sample_3rd != 102):  # 102 is the id of [SEP]
        tgt_txt_3rd = torch.cat((tgt_txt_3rd, sample_3rd.unsqueeze(1)), dim=-1)
        tgt_txt_embed = model.txt_embed(tgt_txt_3rd)
        cur_len += 1
        tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device))
        # out = model.transformer(prefix, tgt_txt_embed)
        out = model.ModelOne(q=obj_embed, k=img_emed, v=img_emed,
                             tgt_embeded=tgt_txt_embed, des_embed=des_embed, obj_embed=obj_embed, img_embed=img_emed,
                             tgt_mask=None)
        logits = model.to_logits(out)[:, -1]
        # logits = logits[:, :-26]
        sample_3rd = torch.argmax(logits, dim=-1)

    output_txt = []
    for token in tgt_txt_3rd[0].tolist():
        if token > 103:
            output_txt.append(token)
    output_txt = tokenizer.convert_ids_to_tokens(output_txt)
    output3 = ''.join(output_txt)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output1)
    print(output2)
    print(output3)
    print('Validate time {}'.format(total_time_str))
    return label_txt, output1, output2, output3


if __name__ == '__main__':
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # parser = argparse.ArgumentParser('AttDes training script', parents=[get_args_parser()])
    # args = parser.parse_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # if args.output_dir:
    #     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # main(args)
    model_name = '005'
    model_path = r'E:\data\Download\models\attribute_desciption\outputs' + '/' + model_name + '/' + 'checkpoint0019.pth'
    objs = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
            "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]
    for obj in objs:
        print(obj)
        out = validate('550695', '550567', obj, model_path)
        sentences1 = out[0][0].replace(';', ',').split(',')
        # gt = ""
        #
        # for i in sentences1:
        #     if obj in i:
        #         gt = i
        # gt = " ".join(jieba.cut(gt))
        # print(gt)
        # for i in out[0]:
        #     i = " ".join(jieba.cut(i))
        #     print(i)
        #     print(gt)
        #     bleu = nltk.translate.bleu_score.sentence_bleu([i], gt)
        #     print(bleu)
Model/CLIP/cn_clip/__init__.py
ADDED
File without changes
|
Model/CLIP/cn_clip/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (144 Bytes). View file
|
|
Model/CLIP/cn_clip/clip/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .bert_tokenizer import FullTokenizer

_tokenizer = FullTokenizer()
from .utils import load_from_name, available_models, tokenize, image_transform, load
Model/CLIP/cn_clip/clip/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (352 Bytes). View file
|
|
Model/CLIP/cn_clip/clip/__pycache__/bert_tokenizer.cpython-38.pyc
ADDED
Binary file (11.2 kB). View file
|
|
Model/CLIP/cn_clip/clip/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (5.99 kB). View file
|
|
Model/CLIP/cn_clip/clip/bert_tokenizer.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Tokenization classes."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
import collections
|
23 |
+
import re
|
24 |
+
import unicodedata
|
25 |
+
import six
|
26 |
+
from functools import lru_cache
|
27 |
+
import os
|
28 |
+
|
29 |
+
@lru_cache()
|
30 |
+
def default_vocab():
|
31 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt")
|
32 |
+
|
33 |
+
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
34 |
+
"""Checks whether the casing config is consistent with the checkpoint name."""
|
35 |
+
|
36 |
+
# The casing has to be passed in by the user and there is no explicit check
|
37 |
+
# as to whether it matches the checkpoint. The casing information probably
|
38 |
+
# should have been stored in the bert_config.json file, but it's not, so
|
39 |
+
# we have to heuristically detect it to validate.
|
40 |
+
|
41 |
+
if not init_checkpoint:
|
42 |
+
return
|
43 |
+
|
44 |
+
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
45 |
+
if m is None:
|
46 |
+
return
|
47 |
+
|
48 |
+
model_name = m.group(1)
|
49 |
+
|
50 |
+
lower_models = [
|
51 |
+
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
52 |
+
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
53 |
+
]
|
54 |
+
|
55 |
+
cased_models = [
|
56 |
+
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
57 |
+
"multi_cased_L-12_H-768_A-12"
|
58 |
+
]
|
59 |
+
|
60 |
+
is_bad_config = False
|
61 |
+
if model_name in lower_models and not do_lower_case:
|
62 |
+
is_bad_config = True
|
63 |
+
actual_flag = "False"
|
64 |
+
case_name = "lowercased"
|
65 |
+
opposite_flag = "True"
|
66 |
+
|
67 |
+
if model_name in cased_models and do_lower_case:
|
68 |
+
is_bad_config = True
|
69 |
+
actual_flag = "True"
|
70 |
+
case_name = "cased"
|
71 |
+
opposite_flag = "False"
|
72 |
+
|
73 |
+
if is_bad_config:
|
74 |
+
raise ValueError(
|
75 |
+
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
76 |
+
"However, `%s` seems to be a %s model, so you "
|
77 |
+
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
78 |
+
"how the model was pre-training. If this error is wrong, please "
|
79 |
+
"just comment out this check." % (actual_flag, init_checkpoint,
|
80 |
+
model_name, case_name, opposite_flag))
|
81 |
+
|
82 |
+
|
83 |
+
def convert_to_unicode(text):
|
84 |
+
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
85 |
+
if six.PY3:
|
86 |
+
if isinstance(text, str):
|
87 |
+
return text
|
88 |
+
elif isinstance(text, bytes):
|
89 |
+
return text.decode("utf-8", "ignore")
|
90 |
+
else:
|
91 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
92 |
+
elif six.PY2:
|
93 |
+
if isinstance(text, str):
|
94 |
+
return text.decode("utf-8", "ignore")
|
95 |
+
elif isinstance(text, unicode):
|
96 |
+
return text
|
97 |
+
else:
|
98 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
99 |
+
else:
|
100 |
+
raise ValueError("Not running on Python2 or Python 3?")
|
101 |
+
|
102 |
+
|
103 |
+
def printable_text(text):
|
104 |
+
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
105 |
+
|
106 |
+
# These functions want `str` for both Python2 and Python3, but in one case
|
107 |
+
# it's a Unicode string and in the other it's a byte string.
|
108 |
+
if six.PY3:
|
109 |
+
if isinstance(text, str):
|
110 |
+
return text
|
111 |
+
elif isinstance(text, bytes):
|
112 |
+
return text.decode("utf-8", "ignore")
|
113 |
+
else:
|
114 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
115 |
+
elif six.PY2:
|
116 |
+
if isinstance(text, str):
|
117 |
+
return text
|
118 |
+
elif isinstance(text, unicode):
|
119 |
+
return text.encode("utf-8")
|
120 |
+
else:
|
121 |
+
raise ValueError("Unsupported string type: %s" % (type(text)))
|
122 |
+
else:
|
123 |
+
raise ValueError("Not running on Python2 or Python 3?")
|
124 |
+
|
125 |
+
|
126 |
+
def load_vocab(vocab_file):
|
127 |
+
"""Loads a vocabulary file into a dictionary."""
|
128 |
+
vocab = collections.OrderedDict()
|
129 |
+
index = 0
|
130 |
+
with open(vocab_file, "r", encoding='utf-8') as reader:
|
131 |
+
while True:
|
132 |
+
token = convert_to_unicode(reader.readline())
|
133 |
+
if not token:
|
134 |
+
break
|
135 |
+
token = token.strip()
|
136 |
+
vocab[token] = index
|
137 |
+
index += 1
|
138 |
+
return vocab
|
139 |
+
|
140 |
+
|
141 |
+
def convert_by_vocab(vocab, items):
|
142 |
+
"""Converts a sequence of [tokens|ids] using the vocab."""
|
143 |
+
output = []
|
144 |
+
for item in items:
|
145 |
+
output.append(vocab[item])
|
146 |
+
return output
|
147 |
+
|
148 |
+
|
149 |
+
def convert_tokens_to_ids(vocab, tokens):
|
150 |
+
return convert_by_vocab(vocab, tokens)
|
151 |
+
|
152 |
+
|
153 |
+
def convert_ids_to_tokens(inv_vocab, ids):
|
154 |
+
return convert_by_vocab(inv_vocab, ids)
|
155 |
+
|
156 |
+
|
157 |
+
def whitespace_tokenize(text):
|
158 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
159 |
+
text = text.strip()
|
160 |
+
if not text:
|
161 |
+
return []
|
162 |
+
tokens = text.split()
|
163 |
+
return tokens
|
164 |
+
|
165 |
+
|
166 |
+
class FullTokenizer(object):
|
167 |
+
"""Runs end-to-end tokenziation."""
|
168 |
+
|
169 |
+
def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
|
170 |
+
self.vocab = load_vocab(vocab_file)
|
171 |
+
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
172 |
+
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
173 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
174 |
+
|
175 |
+
def tokenize(self, text):
|
176 |
+
split_tokens = []
|
177 |
+
for token in self.basic_tokenizer.tokenize(text):
|
178 |
+
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
179 |
+
split_tokens.append(sub_token)
|
180 |
+
|
181 |
+
return split_tokens
|
182 |
+
|
183 |
+
def convert_tokens_to_ids(self, tokens):
|
184 |
+
return convert_by_vocab(self.vocab, tokens)
|
185 |
+
|
186 |
+
def convert_ids_to_tokens(self, ids):
|
187 |
+
return convert_by_vocab(self.inv_vocab, ids)
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
|
191 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
192 |
+
|
193 |
+
def clean_up_tokenization(out_string):
|
194 |
+
""" Clean up a list of simple English tokenization artifacts
|
195 |
+
like spaces before punctuations and abreviated forms.
|
196 |
+
"""
|
197 |
+
out_string = (
|
198 |
+
out_string.replace(" .", ".")
|
199 |
+
.replace(" ?", "?")
|
200 |
+
.replace(" !", "!")
|
201 |
+
.replace(" ,", ",")
|
202 |
+
.replace(" ' ", "'")
|
203 |
+
.replace(" n't", "n't")
|
204 |
+
.replace(" 'm", "'m")
|
205 |
+
.replace(" 's", "'s")
|
206 |
+
.replace(" 've", "'ve")
|
207 |
+
.replace(" 're", "'re")
|
208 |
+
)
|
209 |
+
return out_string
|
210 |
+
|
211 |
+
text = ' '.join(tokens).replace(' ##', '').strip()
|
212 |
+
if clean_up_tokenization_spaces:
|
213 |
+
clean_text = clean_up_tokenization(text)
|
214 |
+
return clean_text
|
215 |
+
else:
|
216 |
+
return text
|
217 |
+
|
218 |
+
def vocab_size(self):
|
219 |
+
return len(self.vocab)
|
220 |
+
|
221 |
+
|
222 |
+
class BasicTokenizer(object):
|
223 |
+
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
224 |
+
|
225 |
+
def __init__(self, do_lower_case=True):
|
226 |
+
"""Constructs a BasicTokenizer.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
do_lower_case: Whether to lower case the input.
|
230 |
+
"""
|
231 |
+
self.do_lower_case = do_lower_case
|
232 |
+
|
233 |
+
def tokenize(self, text):
|
234 |
+
"""Tokenizes a piece of text."""
|
235 |
+
text = convert_to_unicode(text)
|
236 |
+
text = self._clean_text(text)
|
237 |
+
|
238 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
239 |
+
# models. This is also applied to the English models now, but it doesn't
|
240 |
+
# matter since the English models were not trained on any Chinese data
|
241 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
242 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
243 |
+
# words in the English Wikipedia.).
|
244 |
+
text = self._tokenize_chinese_chars(text)
|
245 |
+
|
246 |
+
orig_tokens = whitespace_tokenize(text)
|
247 |
+
split_tokens = []
|
248 |
+
for token in orig_tokens:
|
249 |
+
if self.do_lower_case:
|
250 |
+
token = token.lower()
|
251 |
+
token = self._run_strip_accents(token)
|
252 |
+
split_tokens.extend(self._run_split_on_punc(token))
|
253 |
+
|
254 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
255 |
+
return output_tokens
|
256 |
+
|
257 |
+
def _run_strip_accents(self, text):
|
258 |
+
"""Strips accents from a piece of text."""
|
259 |
+
text = unicodedata.normalize("NFD", text)
|
260 |
+
output = []
|
261 |
+
for char in text:
|
262 |
+
cat = unicodedata.category(char)
|
263 |
+
if cat == "Mn":
|
264 |
+
continue
|
265 |
+
output.append(char)
|
266 |
+
return "".join(output)
|
267 |
+
|
268 |
+
def _run_split_on_punc(self, text):
|
269 |
+
"""Splits punctuation on a piece of text."""
|
270 |
+
chars = list(text)
|
271 |
+
i = 0
|
272 |
+
start_new_word = True
|
273 |
+
output = []
|
274 |
+
while i < len(chars):
|
275 |
+
char = chars[i]
|
276 |
+
if _is_punctuation(char):
|
277 |
+
output.append([char])
|
278 |
+
start_new_word = True
|
279 |
+
else:
|
280 |
+
if start_new_word:
|
281 |
+
output.append([])
|
282 |
+
start_new_word = False
|
283 |
+
output[-1].append(char)
|
284 |
+
i += 1
|
285 |
+
|
286 |
+
return ["".join(x) for x in output]
|
287 |
+
|
288 |
+
def _tokenize_chinese_chars(self, text):
|
289 |
+
"""Adds whitespace around any CJK character."""
|
290 |
+
output = []
|
291 |
+
for char in text:
|
292 |
+
cp = ord(char)
|
293 |
+
if self._is_chinese_char(cp):
|
294 |
+
output.append(" ")
|
295 |
+
output.append(char)
|
296 |
+
output.append(" ")
|
297 |
+
else:
|
298 |
+
output.append(char)
|
299 |
+
return "".join(output)
|
300 |
+
|
301 |
+
def _is_chinese_char(self, cp):
|
302 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
303 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
304 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
305 |
+
#
|
306 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
307 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
308 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
309 |
+
# space-separated words, so they are not treated specially and handled
|
310 |
+
# like the all of the other languages.
|
311 |
+
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
312 |
+
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
313 |
+
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
314 |
+
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
315 |
+
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
316 |
+
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
317 |
+
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
318 |
+
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
319 |
+
return True
|
320 |
+
|
321 |
+
return False
|
322 |
+
|
323 |
+
def _clean_text(self, text):
|
324 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
325 |
+
output = []
|
326 |
+
for char in text:
|
327 |
+
cp = ord(char)
|
328 |
+
if cp == 0 or cp == 0xfffd or _is_control(char):
|
329 |
+
continue
|
330 |
+
if _is_whitespace(char):
|
331 |
+
output.append(" ")
|
332 |
+
else:
|
333 |
+
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace-separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyway, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
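As a quick illustration of the greedy longest-match-first loop implemented above, the sketch below runs WordpieceTokenizer against a made-up vocabulary. The toy_vocab set is hypothetical (the real tokenizer is built from the vocab.txt shipped in this folder), and the sketch assumes the module-level helpers convert_to_unicode and whitespace_tokenize defined earlier in bert_tokenizer.py.

toy_vocab = {"[UNK]", "un", "##aff", "##able"}
wp = WordpieceTokenizer(vocab=toy_vocab)

print(wp.tokenize("unaffable"))    # ['un', '##aff', '##able']: longest prefix first, then '##' continuations
print(wp.tokenize("unknownword"))  # ['[UNK]']: no sequence of sub-tokens covers the whole word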
Model/CLIP/cn_clip/clip/configuration_bert.py
ADDED
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

logger = logging.getLogger(__name__)


class BertConfig(object):
    r"""
    :class:`~transformers.BertConfig` is the configuration class to store the configuration of a
    `BertModel`.

    Arguments:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
        hidden_size: Size of the encoder layers and the pooler layer.
        num_hidden_layers: Number of hidden layers in the Transformer encoder.
        num_attention_heads: Number of attention heads for each attention layer in
            the Transformer encoder.
        intermediate_size: The size of the "intermediate" (i.e., feed-forward)
            layer in the Transformer encoder.
        hidden_act: The non-linear activation function (function or string) in the
            encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
        hidden_dropout_prob: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
        max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `BertModel`.
        initializer_range: The stddev of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
    """

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 output_attentions=False,
                 output_hidden_states=False
                 ):
        self.vocab_size = vocab_size_or_config_json_file
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
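A minimal sketch of how this configuration class is used elsewhere in this commit; the import path matches the one in model.py below, and the values mirror the RBT3-chinese text config added later in this diff.

from cn_clip.clip.configuration_bert import BertConfig

# Text tower of the RN50@RBT3-chinese pairing: a 3-layer Chinese BERT.
config = BertConfig(
    vocab_size_or_config_json_file=21128,
    hidden_size=768,
    num_hidden_layers=3,
    num_attention_heads=12,
    intermediate_size=3072,
)
print(config.vocab_size, config.num_hidden_layers)  # 21128 3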
Model/CLIP/cn_clip/clip/model.py
ADDED
@@ -0,0 +1,504 @@
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
from itertools import repeat
|
4 |
+
import collections.abc
|
5 |
+
|
6 |
+
import math
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from torch.utils.checkpoint import checkpoint
|
13 |
+
|
14 |
+
from cn_clip.clip import _tokenizer
|
15 |
+
from cn_clip.clip.configuration_bert import BertConfig
|
16 |
+
from cn_clip.clip.modeling_bert import BertModel
|
17 |
+
|
18 |
+
|
19 |
+
class Bottleneck(nn.Module):
|
20 |
+
expansion = 4
|
21 |
+
|
22 |
+
def __init__(self, inplanes, planes, stride=1):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
26 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
27 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
28 |
+
|
29 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
30 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
31 |
+
|
32 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
33 |
+
|
34 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
35 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
36 |
+
|
37 |
+
self.relu = nn.ReLU(inplace=True)
|
38 |
+
self.downsample = None
|
39 |
+
self.stride = stride
|
40 |
+
|
41 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
42 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
43 |
+
self.downsample = nn.Sequential(OrderedDict([
|
44 |
+
("-1", nn.AvgPool2d(stride)),
|
45 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
46 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
47 |
+
]))
|
48 |
+
|
49 |
+
def forward(self, x: torch.Tensor):
|
50 |
+
identity = x
|
51 |
+
|
52 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
53 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
54 |
+
out = self.avgpool(out)
|
55 |
+
out = self.bn3(self.conv3(out))
|
56 |
+
|
57 |
+
if self.downsample is not None:
|
58 |
+
identity = self.downsample(x)
|
59 |
+
|
60 |
+
out += identity
|
61 |
+
out = self.relu(out)
|
62 |
+
return out
|
63 |
+
|
64 |
+
|
65 |
+
class AttentionPool2d(nn.Module):
|
66 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
67 |
+
super().__init__()
|
68 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
69 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
70 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
71 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
72 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
73 |
+
self.num_heads = num_heads
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
77 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
78 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
79 |
+
x, _ = F.multi_head_attention_forward(
|
80 |
+
query=x, key=x, value=x,
|
81 |
+
embed_dim_to_check=x.shape[-1],
|
82 |
+
num_heads=self.num_heads,
|
83 |
+
q_proj_weight=self.q_proj.weight,
|
84 |
+
k_proj_weight=self.k_proj.weight,
|
85 |
+
v_proj_weight=self.v_proj.weight,
|
86 |
+
in_proj_weight=None,
|
87 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
88 |
+
bias_k=None,
|
89 |
+
bias_v=None,
|
90 |
+
add_zero_attn=False,
|
91 |
+
dropout_p=0,
|
92 |
+
out_proj_weight=self.c_proj.weight,
|
93 |
+
out_proj_bias=self.c_proj.bias,
|
94 |
+
use_separate_proj_weight=True,
|
95 |
+
training=self.training,
|
96 |
+
need_weights=False
|
97 |
+
)
|
98 |
+
|
99 |
+
return x[0]
|
100 |
+
|
101 |
+
|
102 |
+
class ModifiedResNet(nn.Module):
|
103 |
+
"""
|
104 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
105 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
106 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
107 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
111 |
+
super().__init__()
|
112 |
+
self.output_dim = output_dim
|
113 |
+
self.input_resolution = input_resolution
|
114 |
+
|
115 |
+
# the 3-layer stem
|
116 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
117 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
118 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
119 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
120 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
121 |
+
self.bn3 = nn.BatchNorm2d(width)
|
122 |
+
self.avgpool = nn.AvgPool2d(2)
|
123 |
+
self.relu = nn.ReLU(inplace=True)
|
124 |
+
|
125 |
+
# residual layers
|
126 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
127 |
+
self.layer1 = self._make_layer(width, layers[0])
|
128 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
129 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
130 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
131 |
+
|
132 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
133 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
134 |
+
|
135 |
+
def _make_layer(self, planes, blocks, stride=1):
|
136 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
137 |
+
|
138 |
+
self._inplanes = planes * Bottleneck.expansion
|
139 |
+
for _ in range(1, blocks):
|
140 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
141 |
+
|
142 |
+
return nn.Sequential(*layers)
|
143 |
+
|
144 |
+
@torch.jit.ignore
|
145 |
+
def set_grad_checkpointing(self, enable=True):
|
146 |
+
# FIXME support for non-transformer
|
147 |
+
pass
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
def stem(x):
|
151 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
152 |
+
x = self.relu(bn(conv(x)))
|
153 |
+
x = self.avgpool(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
x = x.type(self.conv1.weight.dtype)
|
157 |
+
x = stem(x)
|
158 |
+
x = self.layer1(x)
|
159 |
+
x = self.layer2(x)
|
160 |
+
x = self.layer3(x)
|
161 |
+
x = self.layer4(x)
|
162 |
+
x = self.attnpool(x)
|
163 |
+
|
164 |
+
return x
|
165 |
+
|
166 |
+
|
167 |
+
class LayerNorm(nn.LayerNorm):
|
168 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
169 |
+
|
170 |
+
def forward(self, x: torch.Tensor):
|
171 |
+
orig_type = x.dtype
|
172 |
+
ret = super().forward(x.type(torch.float32))
|
173 |
+
return ret.type(orig_type)
|
174 |
+
|
175 |
+
|
176 |
+
class QuickGELU(nn.Module):
|
177 |
+
def forward(self, x: torch.Tensor):
|
178 |
+
return x * torch.sigmoid(1.702 * x)
|
179 |
+
|
180 |
+
|
181 |
+
class ResidualAttentionBlock(nn.Module):
|
182 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
183 |
+
super().__init__()
|
184 |
+
|
185 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
186 |
+
self.ln_1 = LayerNorm(d_model)
|
187 |
+
self.mlp = nn.Sequential(OrderedDict([
|
188 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
189 |
+
("gelu", QuickGELU()),
|
190 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
191 |
+
]))
|
192 |
+
self.ln_2 = LayerNorm(d_model)
|
193 |
+
self.attn_mask = attn_mask
|
194 |
+
|
195 |
+
def attention(self, x: torch.Tensor):
|
196 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
197 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
198 |
+
|
199 |
+
def forward(self, x: torch.Tensor):
|
200 |
+
x = x + self.attention(self.ln_1(x))
|
201 |
+
x = x + self.mlp(self.ln_2(x))
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class Transformer(nn.Module):
|
206 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
207 |
+
super().__init__()
|
208 |
+
self.width = width
|
209 |
+
self.layers = layers
|
210 |
+
self.grad_checkpointing = False
|
211 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
212 |
+
|
213 |
+
def forward(self, x: torch.Tensor):
|
214 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
215 |
+
for r in self.resblocks:
|
216 |
+
x = checkpoint(r, x)
|
217 |
+
return x
|
218 |
+
return self.resblocks(x)
|
219 |
+
|
220 |
+
|
221 |
+
class VisualTransformer(nn.Module):
|
222 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
223 |
+
super().__init__()
|
224 |
+
self.input_resolution = input_resolution
|
225 |
+
self.grid_size = (self.input_resolution // patch_size, self.input_resolution // patch_size)
|
226 |
+
self.output_dim = output_dim
|
227 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
228 |
+
|
229 |
+
scale = width ** -0.5
|
230 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
231 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
232 |
+
self.ln_pre = LayerNorm(width)
|
233 |
+
|
234 |
+
self.transformer = Transformer(width, layers, heads)
|
235 |
+
|
236 |
+
self.ln_post = LayerNorm(width)
|
237 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
238 |
+
|
239 |
+
@torch.jit.ignore
|
240 |
+
def set_grad_checkpointing(self, enable=True):
|
241 |
+
self.transformer.grad_checkpointing = enable
|
242 |
+
|
243 |
+
def forward(self, x: torch.Tensor):
|
244 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
245 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
246 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
247 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
248 |
+
x = x + self.positional_embedding.to(x.dtype)
|
249 |
+
x = self.ln_pre(x)
|
250 |
+
|
251 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
252 |
+
x = self.transformer(x)
|
253 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
254 |
+
|
255 |
+
x = self.ln_post(x[:, 0, :])
|
256 |
+
|
257 |
+
if self.proj is not None:
|
258 |
+
x = x @ self.proj
|
259 |
+
|
260 |
+
return x
|
261 |
+
|
262 |
+
|
263 |
+
class CLIP(nn.Module):
|
264 |
+
def __init__(self,
|
265 |
+
embed_dim: int,
|
266 |
+
# vision
|
267 |
+
image_resolution: int,
|
268 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
269 |
+
vision_width: int,
|
270 |
+
vision_patch_size: int,
|
271 |
+
# text
|
272 |
+
vocab_size: int,
|
273 |
+
text_attention_probs_dropout_prob: float,
|
274 |
+
text_hidden_act: str,
|
275 |
+
text_hidden_dropout_prob: float,
|
276 |
+
text_hidden_size: int,
|
277 |
+
text_initializer_range: float,
|
278 |
+
text_intermediate_size: int,
|
279 |
+
text_max_position_embeddings: int,
|
280 |
+
text_num_attention_heads: int,
|
281 |
+
text_num_hidden_layers: int,
|
282 |
+
text_type_vocab_size: int,
|
283 |
+
tokenizer = _tokenizer,
|
284 |
+
# vision head width, added this param for ViT-H
|
285 |
+
vision_head_width: int = 64,
|
286 |
+
):
|
287 |
+
super().__init__()
|
288 |
+
|
289 |
+
if isinstance(vision_layers, (tuple, list)):
|
290 |
+
vision_heads = vision_width * 32 // vision_head_width
|
291 |
+
self.visual = ModifiedResNet(
|
292 |
+
layers=vision_layers,
|
293 |
+
output_dim=embed_dim,
|
294 |
+
heads=vision_heads,
|
295 |
+
input_resolution=image_resolution,
|
296 |
+
width=vision_width
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
vision_heads = vision_width // vision_head_width
|
300 |
+
self.visual = VisualTransformer(
|
301 |
+
input_resolution=image_resolution,
|
302 |
+
patch_size=vision_patch_size,
|
303 |
+
width=vision_width,
|
304 |
+
layers=vision_layers,
|
305 |
+
heads=vision_heads,
|
306 |
+
output_dim=embed_dim
|
307 |
+
)
|
308 |
+
|
309 |
+
self.bert_config = BertConfig(
|
310 |
+
vocab_size_or_config_json_file=vocab_size,
|
311 |
+
hidden_size=text_hidden_size,
|
312 |
+
num_hidden_layers=text_num_hidden_layers,
|
313 |
+
num_attention_heads=text_num_attention_heads,
|
314 |
+
intermediate_size=text_intermediate_size,
|
315 |
+
hidden_act=text_hidden_act,
|
316 |
+
hidden_dropout_prob=text_hidden_dropout_prob,
|
317 |
+
attention_probs_dropout_prob=text_attention_probs_dropout_prob,
|
318 |
+
max_position_embeddings=text_max_position_embeddings,
|
319 |
+
type_vocab_size=text_type_vocab_size,
|
320 |
+
initializer_range=text_initializer_range,
|
321 |
+
layer_norm_eps=1e-12,
|
322 |
+
)
|
323 |
+
self.bert = BertModel(self.bert_config)
|
324 |
+
|
325 |
+
self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
|
326 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
327 |
+
|
328 |
+
self.tokenizer = tokenizer
|
329 |
+
|
330 |
+
self.initialize_parameters()
|
331 |
+
|
332 |
+
def initialize_parameters(self):
|
333 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
334 |
+
|
335 |
+
if isinstance(self.visual, ModifiedResNet):
|
336 |
+
if self.visual.attnpool is not None:
|
337 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
338 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
339 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
340 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
341 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
342 |
+
|
343 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
344 |
+
for name, param in resnet_block.named_parameters():
|
345 |
+
if name.endswith("bn3.weight"):
|
346 |
+
nn.init.zeros_(param)
|
347 |
+
|
348 |
+
if self.text_projection is not None:
|
349 |
+
nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
|
350 |
+
|
351 |
+
@torch.jit.ignore
|
352 |
+
def set_grad_checkpointing(self, enable=True):
|
353 |
+
self.visual.set_grad_checkpointing(enable)
|
354 |
+
self.bert.set_grad_checkpointing(enable)
|
355 |
+
|
356 |
+
@property
|
357 |
+
def dtype(self):
|
358 |
+
return self.visual.conv1.weight.dtype
|
359 |
+
|
360 |
+
def encode_image(self, image):
|
361 |
+
return self.visual(image.type(self.dtype))
|
362 |
+
|
363 |
+
def encode_text(self, text):
|
364 |
+
pad_index = self.tokenizer.vocab['[PAD]']
|
365 |
+
attn_mask = text.ne(pad_index).type(self.dtype)
|
366 |
+
x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
|
367 |
+
return x[:, 0, :] @ self.text_projection
|
368 |
+
|
369 |
+
def forward(self, image, text):
|
370 |
+
assert image is not None or text is not None, "text and image cannot both be None!"
|
371 |
+
|
372 |
+
if image is None:
|
373 |
+
return self.encode_text(text)
|
374 |
+
elif text is None:
|
375 |
+
return self.encode_image(image)
|
376 |
+
|
377 |
+
image_features = self.encode_image(image)
|
378 |
+
text_features = self.encode_text(text)
|
379 |
+
|
380 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
381 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
382 |
+
|
383 |
+
return image_features, text_features, self.logit_scale.exp()
|
384 |
+
|
385 |
+
def get_similarity(self, image, text):
|
386 |
+
image_features = self.encode_image(image)
|
387 |
+
text_features = self.encode_text(text)
|
388 |
+
|
389 |
+
# normalized features
|
390 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
391 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
392 |
+
|
393 |
+
# cosine similarity as logits
|
394 |
+
logit_scale = self.logit_scale.exp()
|
395 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
396 |
+
logits_per_text = logits_per_image.t()
|
397 |
+
|
398 |
+
# shape = [global_batch_size, global_batch_size]
|
399 |
+
return logits_per_image, logits_per_text
|
400 |
+
|
401 |
+
|
402 |
+
def convert_models_to_fp32(model):
|
403 |
+
for p in model.parameters():
|
404 |
+
p.data = p.data.float()
|
405 |
+
if p.grad is not None:
|
406 |
+
p.grad.data = p.grad.data.float()
|
407 |
+
|
408 |
+
|
409 |
+
def convert_weights(model: nn.Module):
|
410 |
+
"""Convert applicable model parameters to fp16"""
|
411 |
+
|
412 |
+
def _convert_weights_to_fp16(l):
|
413 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
414 |
+
l.weight.data = l.weight.data.half()
|
415 |
+
if l.bias is not None:
|
416 |
+
l.bias.data = l.bias.data.half()
|
417 |
+
|
418 |
+
if isinstance(l, nn.MultiheadAttention):
|
419 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
420 |
+
tensor = getattr(l, attr)
|
421 |
+
if tensor is not None:
|
422 |
+
tensor.data = tensor.data.half()
|
423 |
+
|
424 |
+
if isinstance(l, BertModel):
|
425 |
+
l.to(torch.half)
|
426 |
+
|
427 |
+
for name in ["text_projection", "proj"]:
|
428 |
+
if hasattr(l, name):
|
429 |
+
attr = getattr(l, name)
|
430 |
+
if attr is not None:
|
431 |
+
attr.data = attr.data.half()
|
432 |
+
|
433 |
+
model.apply(_convert_weights_to_fp16)
|
434 |
+
|
435 |
+
|
436 |
+
def restore_model(model, clip_state_dict: dict, bert_state_dict: dict):
|
437 |
+
merged_state_dict = {}
|
438 |
+
|
439 |
+
# use clip_state_dict to initialize the image encoder & logit scale
|
440 |
+
if clip_state_dict is not None:
|
441 |
+
for k, v in clip_state_dict.items():
|
442 |
+
if k.startswith("visual") or k == "logit_scale":
|
443 |
+
merged_state_dict[k] = v
|
444 |
+
|
445 |
+
# use bert_state_dict to initialize the text encoder
|
446 |
+
if bert_state_dict is not None:
|
447 |
+
for k, v in bert_state_dict.items():
|
448 |
+
if k.startswith("bert") and "bert.pooler" not in k:
|
449 |
+
merged_state_dict[k] = v
|
450 |
+
|
451 |
+
convert_weights(model)
|
452 |
+
resize_pos_embed(merged_state_dict, model)
|
453 |
+
model.load_state_dict(merged_state_dict, strict=False)
|
454 |
+
return model.eval()
|
455 |
+
|
456 |
+
|
457 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1, prefix=""):
|
458 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
459 |
+
old_pos_embed = state_dict.get(prefix + 'visual.positional_embedding', None)
|
460 |
+
model = model.module if hasattr(model, 'module') else model
|
461 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
462 |
+
return
|
463 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
464 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
465 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
466 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
467 |
+
return
|
468 |
+
|
469 |
+
if extra_tokens:
|
470 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
471 |
+
else:
|
472 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
473 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
474 |
+
|
475 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
476 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
477 |
+
pos_emb_img = F.interpolate(
|
478 |
+
pos_emb_img,
|
479 |
+
size=grid_size,
|
480 |
+
mode=interpolation,
|
481 |
+
align_corners=True,
|
482 |
+
)
|
483 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
484 |
+
if pos_emb_tok is not None:
|
485 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
486 |
+
else:
|
487 |
+
new_pos_embed = pos_emb_img
|
488 |
+
state_dict[prefix + 'visual.positional_embedding'] = new_pos_embed
|
489 |
+
|
490 |
+
|
491 |
+
# From PyTorch internals
|
492 |
+
def _ntuple(n):
|
493 |
+
def parse(x):
|
494 |
+
if isinstance(x, collections.abc.Iterable):
|
495 |
+
return x
|
496 |
+
return tuple(repeat(x, n))
|
497 |
+
return parse
|
498 |
+
|
499 |
+
|
500 |
+
to_1tuple = _ntuple(1)
|
501 |
+
to_2tuple = _ntuple(2)
|
502 |
+
to_3tuple = _ntuple(3)
|
503 |
+
to_4tuple = _ntuple(4)
|
504 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
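To make the contract of CLIP.forward concrete, here is a small sketch of how its three return values (image features, text features, logit scale) are usually combined into contrastive logits, mirroring get_similarity above. The feature tensors are random stand-ins rather than real encoder outputs, and 512 is the ViT-B-16 embed_dim from the configs below.

import torch

image_features = torch.randn(4, 512)   # stand-in for encode_image(images)
text_features = torch.randn(4, 512)    # stand-in for encode_text(token_ids)
logit_scale = torch.tensor(1 / 0.07)   # value of logit_scale.exp() at initialization

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logits_per_image = logit_scale * image_features @ text_features.t()  # shape [4, 4]
logits_per_text = logits_per_image.t()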
Model/CLIP/cn_clip/clip/model_configs/RBT3-chinese.json
ADDED
@@ -0,0 +1,13 @@
{
    "vocab_size": 21128,
    "text_attention_probs_dropout_prob": 0.1,
    "text_hidden_act": "gelu",
    "text_hidden_dropout_prob": 0.1,
    "text_hidden_size": 768,
    "text_initializer_range": 0.02,
    "text_intermediate_size": 3072,
    "text_max_position_embeddings": 512,
    "text_num_attention_heads": 12,
    "text_num_hidden_layers": 3,
    "text_type_vocab_size": 2
}
Model/CLIP/cn_clip/clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,7 @@
{
    "embed_dim": 1024,
    "image_resolution": 224,
    "vision_layers": "[3,4,6,3]",
    "vision_width": 64,
    "vision_patch_size": null
}
Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
ADDED
@@ -0,0 +1,13 @@
{
    "vocab_size": 21128,
    "text_attention_probs_dropout_prob": 0.1,
    "text_hidden_act": "gelu",
    "text_hidden_dropout_prob": 0.1,
    "text_hidden_size": 768,
    "text_initializer_range": 0.02,
    "text_intermediate_size": 3072,
    "text_max_position_embeddings": 512,
    "text_num_attention_heads": 12,
    "text_num_hidden_layers": 12,
    "text_type_vocab_size": 2
}
Model/CLIP/cn_clip/clip/model_configs/RoBERTa-wwm-ext-large-chinese.json
ADDED
@@ -0,0 +1,13 @@
{
    "vocab_size": 21128,
    "text_attention_probs_dropout_prob": 0.1,
    "text_hidden_act": "gelu",
    "text_hidden_dropout_prob": 0.1,
    "text_hidden_size": 1024,
    "text_initializer_range": 0.02,
    "text_intermediate_size": 4096,
    "text_max_position_embeddings": 512,
    "text_num_attention_heads": 16,
    "text_num_hidden_layers": 24,
    "text_type_vocab_size": 2
}
Model/CLIP/cn_clip/clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,7 @@
{
    "embed_dim": 512,
    "image_resolution": 224,
    "vision_layers": 12,
    "vision_width": 768,
    "vision_patch_size": 16
}
Model/CLIP/cn_clip/clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,7 @@
{
    "embed_dim": 512,
    "image_resolution": 224,
    "vision_layers": 12,
    "vision_width": 768,
    "vision_patch_size": 32
}
Model/CLIP/cn_clip/clip/model_configs/ViT-H-14.json
ADDED
@@ -0,0 +1,8 @@
{
    "embed_dim": 1024,
    "image_resolution": 224,
    "vision_layers": 32,
    "vision_width": 1280,
    "vision_head_width": 80,
    "vision_patch_size": 14
}
Model/CLIP/cn_clip/clip/model_configs/ViT-L-14-336.json
ADDED
@@ -0,0 +1,7 @@
{
    "embed_dim": 768,
    "image_resolution": 336,
    "vision_layers": 24,
    "vision_width": 1024,
    "vision_patch_size": 14
}
Model/CLIP/cn_clip/clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,7 @@
{
    "embed_dim": 768,
    "image_resolution": 224,
    "vision_layers": 24,
    "vision_width": 1024,
    "vision_patch_size": 14
}
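The JSON files above split the model definition into vision configs (embed_dim, vision_*) and text configs (vocab_size, text_*). Below is a sketch of how the two halves can be merged into keyword arguments for the CLIP constructor, assuming it is run from this model_configs directory; for_learn.py below does essentially the same thing.

import json

with open("ViT-B-16.json") as fv, open("RoBERTa-wwm-ext-base-chinese.json") as ft:
    model_info = json.load(fv)
    model_info.update(json.load(ft))

# RN50.json stores vision_layers as the string "[3,4,6,3]", so convert it when present.
if isinstance(model_info["vision_layers"], str):
    model_info["vision_layers"] = eval(model_info["vision_layers"])

# model = CLIP(**model_info)  # as constructed in cn_clip.clip.model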
Model/CLIP/cn_clip/clip/model_configs/for_learn.py
ADDED
@@ -0,0 +1,16 @@
import json
from pathlib import Path
import os

vision_model = "ViT-B-16"
vision_model_config_file = \
    Path(__file__).parent / f"{vision_model.replace('/', '-')}.json"
print('Loading vision model config from', vision_model_config_file)
assert os.path.exists(vision_model_config_file)
with open(vision_model_config_file, 'r') as fv:
    model_info = json.load(fv)  # keep the parsed dict so its keys can be indexed below
print('Model info:', model_info)
if isinstance(model_info['vision_layers'], str):
    model_info['vision_layers'] = eval(model_info['vision_layers'])
Model/CLIP/cn_clip/clip/modeling_bert.py
ADDED
@@ -0,0 +1,460 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model. """
|
17 |
+
|
18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
19 |
+
|
20 |
+
import json
|
21 |
+
import logging
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import sys
|
25 |
+
from io import open
|
26 |
+
|
27 |
+
import torch
|
28 |
+
from torch import nn
|
29 |
+
from torch.utils.checkpoint import checkpoint
|
30 |
+
|
31 |
+
from .configuration_bert import BertConfig
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
def gelu(x):
|
36 |
+
""" Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
37 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
38 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
39 |
+
Also see https://arxiv.org/abs/1606.08415
|
40 |
+
"""
|
41 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
42 |
+
|
43 |
+
def gelu_new(x):
|
44 |
+
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
45 |
+
Also see https://arxiv.org/abs/1606.08415
|
46 |
+
"""
|
47 |
+
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
48 |
+
|
49 |
+
def swish(x):
|
50 |
+
return x * torch.sigmoid(x)
|
51 |
+
|
52 |
+
|
53 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
|
54 |
+
|
55 |
+
|
56 |
+
BertLayerNorm = torch.nn.LayerNorm
|
57 |
+
|
58 |
+
class BertEmbeddings(nn.Module):
|
59 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
60 |
+
"""
|
61 |
+
def __init__(self, config):
|
62 |
+
super(BertEmbeddings, self).__init__()
|
63 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
64 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
65 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
66 |
+
|
67 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
68 |
+
# any TensorFlow checkpoint file
|
69 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
70 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
71 |
+
|
72 |
+
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
73 |
+
seq_length = input_ids.size(1)
|
74 |
+
if position_ids is None:
|
75 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
76 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
77 |
+
if token_type_ids is None:
|
78 |
+
token_type_ids = torch.zeros_like(input_ids)
|
79 |
+
|
80 |
+
words_embeddings = self.word_embeddings(input_ids)
|
81 |
+
position_embeddings = self.position_embeddings(position_ids)
|
82 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
83 |
+
|
84 |
+
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
85 |
+
embeddings = self.LayerNorm(embeddings)
|
86 |
+
embeddings = self.dropout(embeddings)
|
87 |
+
return embeddings
|
88 |
+
|
89 |
+
|
90 |
+
class BertSelfAttention(nn.Module):
|
91 |
+
def __init__(self, config):
|
92 |
+
super(BertSelfAttention, self).__init__()
|
93 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
94 |
+
raise ValueError(
|
95 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
96 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
97 |
+
self.output_attentions = config.output_attentions
|
98 |
+
|
99 |
+
self.num_attention_heads = config.num_attention_heads
|
100 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
101 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
102 |
+
|
103 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
104 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
105 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
106 |
+
|
107 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
108 |
+
|
109 |
+
def transpose_for_scores(self, x):
|
110 |
+
|
111 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
112 |
+
x = x.view(*new_x_shape)
|
113 |
+
return x.permute(0, 2, 1, 3)
|
114 |
+
|
115 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
116 |
+
mixed_query_layer = self.query(hidden_states)
|
117 |
+
mixed_key_layer = self.key(hidden_states)
|
118 |
+
mixed_value_layer = self.value(hidden_states)
|
119 |
+
|
120 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
121 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
122 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
123 |
+
|
124 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
125 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
126 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
127 |
+
if attention_mask is not None:
|
128 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
129 |
+
attention_scores = attention_scores + attention_mask
|
130 |
+
|
131 |
+
# Normalize the attention scores to probabilities.
|
132 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
133 |
+
|
134 |
+
# This is actually dropping out entire tokens to attend to, which might
|
135 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
136 |
+
attention_probs = self.dropout(attention_probs)
|
137 |
+
|
138 |
+
# Mask heads if we want to
|
139 |
+
if head_mask is not None:
|
140 |
+
attention_probs = attention_probs * head_mask
|
141 |
+
|
142 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
143 |
+
|
144 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
145 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
146 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
147 |
+
|
148 |
+
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
|
149 |
+
return outputs
|
150 |
+
|
151 |
+
|
152 |
+
class BertSelfOutput(nn.Module):
|
153 |
+
def __init__(self, config):
|
154 |
+
super(BertSelfOutput, self).__init__()
|
155 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
156 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
157 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
158 |
+
|
159 |
+
def forward(self, hidden_states, input_tensor):
|
160 |
+
hidden_states = self.dense(hidden_states)
|
161 |
+
hidden_states = self.dropout(hidden_states)
|
162 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
163 |
+
return hidden_states
|
164 |
+
|
165 |
+
|
166 |
+
class BertAttention(nn.Module):
|
167 |
+
def __init__(self, config):
|
168 |
+
super(BertAttention, self).__init__()
|
169 |
+
self.self = BertSelfAttention(config)
|
170 |
+
self.output = BertSelfOutput(config)
|
171 |
+
self.pruned_heads = set()
|
172 |
+
|
173 |
+
def forward(self, input_tensor, attention_mask=None, head_mask=None):
|
174 |
+
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
175 |
+
attention_output = self.output(self_outputs[0], input_tensor)
|
176 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
177 |
+
return outputs
|
178 |
+
|
179 |
+
|
180 |
+
class BertIntermediate(nn.Module):
|
181 |
+
def __init__(self, config):
|
182 |
+
super(BertIntermediate, self).__init__()
|
183 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
184 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
185 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
186 |
+
else:
|
187 |
+
self.intermediate_act_fn = config.hidden_act
|
188 |
+
|
189 |
+
def forward(self, hidden_states):
|
190 |
+
hidden_states = self.dense(hidden_states)
|
191 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
192 |
+
return hidden_states
|
193 |
+
|
194 |
+
|
195 |
+
class BertOutput(nn.Module):
|
196 |
+
def __init__(self, config):
|
197 |
+
super(BertOutput, self).__init__()
|
198 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
199 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
200 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
201 |
+
|
202 |
+
def forward(self, hidden_states, input_tensor):
|
203 |
+
hidden_states = self.dense(hidden_states)
|
204 |
+
hidden_states = self.dropout(hidden_states)
|
205 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
206 |
+
return hidden_states
|
207 |
+
|
208 |
+
|
209 |
+
class BertLayer(nn.Module):
|
210 |
+
def __init__(self, config):
|
211 |
+
super(BertLayer, self).__init__()
|
212 |
+
self.attention = BertAttention(config)
|
213 |
+
self.intermediate = BertIntermediate(config)
|
214 |
+
self.output = BertOutput(config)
|
215 |
+
|
216 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
217 |
+
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
218 |
+
attention_output = attention_outputs[0]
|
219 |
+
intermediate_output = self.intermediate(attention_output)
|
220 |
+
layer_output = self.output(intermediate_output, attention_output)
|
221 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
222 |
+
if len(outputs) == 1:
|
223 |
+
return outputs[0]
|
224 |
+
return outputs
|
225 |
+
|
226 |
+
|
227 |
+
class BertEncoder(nn.Module):
|
228 |
+
def __init__(self, config):
|
229 |
+
super(BertEncoder, self).__init__()
|
230 |
+
self.output_attentions = config.output_attentions
|
231 |
+
self.output_hidden_states = config.output_hidden_states
|
232 |
+
self.grad_checkpointing = False
|
233 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
234 |
+
|
235 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
236 |
+
all_hidden_states = ()
|
237 |
+
all_attentions = ()
|
238 |
+
for i, layer_module in enumerate(self.layer):
|
239 |
+
if self.output_hidden_states:
|
240 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
241 |
+
|
242 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
243 |
+
layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i])
|
244 |
+
else:
|
245 |
+
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
246 |
+
if not isinstance(layer_outputs, tuple):
|
247 |
+
layer_outputs = (layer_outputs, )
|
248 |
+
hidden_states = layer_outputs[0]
|
249 |
+
|
250 |
+
if self.output_attentions:
|
251 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
252 |
+
|
253 |
+
# Add last layer
|
254 |
+
if self.output_hidden_states:
|
255 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
256 |
+
|
257 |
+
outputs = (hidden_states,)
|
258 |
+
if self.output_hidden_states:
|
259 |
+
outputs = outputs + (all_hidden_states,)
|
260 |
+
if self.output_attentions:
|
261 |
+
outputs = outputs + (all_attentions,)
|
262 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
263 |
+
|
264 |
+
|
265 |
+
class BertPooler(nn.Module):
|
266 |
+
def __init__(self, config):
|
267 |
+
super(BertPooler, self).__init__()
|
268 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
269 |
+
self.activation = nn.Tanh()
|
270 |
+
|
271 |
+
def forward(self, hidden_states):
|
272 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
273 |
+
# to the first token.
|
274 |
+
first_token_tensor = hidden_states[:, 0]
|
275 |
+
pooled_output = self.dense(first_token_tensor)
|
276 |
+
pooled_output = self.activation(pooled_output)
|
277 |
+
return pooled_output
|
278 |
+
|
279 |
+
|
280 |
+
class BertPredictionHeadTransform(nn.Module):
|
281 |
+
def __init__(self, config):
|
282 |
+
super(BertPredictionHeadTransform, self).__init__()
|
283 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
284 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
285 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
286 |
+
else:
|
287 |
+
self.transform_act_fn = config.hidden_act
|
288 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
289 |
+
|
290 |
+
def forward(self, hidden_states):
|
291 |
+
hidden_states = self.dense(hidden_states)
|
292 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
293 |
+
hidden_states = self.LayerNorm(hidden_states)
|
294 |
+
return hidden_states
|
295 |
+
|
296 |
+
|
297 |
+
class BertLMPredictionHead(nn.Module):
|
298 |
+
def __init__(self, config):
|
299 |
+
super(BertLMPredictionHead, self).__init__()
|
300 |
+
self.transform = BertPredictionHeadTransform(config)
|
301 |
+
|
302 |
+
# The output weights are the same as the input embeddings, but there is
|
303 |
+
# an output-only bias for each token.
|
304 |
+
self.decoder = nn.Linear(config.hidden_size,
|
305 |
+
config.vocab_size,
|
306 |
+
bias=False)
|
307 |
+
|
308 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
309 |
+
|
310 |
+
def forward(self, hidden_states):
|
311 |
+
hidden_states = self.transform(hidden_states)
|
312 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
313 |
+
return hidden_states
|
314 |
+
|
315 |
+
|
316 |
+
class BertOnlyMLMHead(nn.Module):
|
317 |
+
def __init__(self, config):
|
318 |
+
super(BertOnlyMLMHead, self).__init__()
|
319 |
+
self.predictions = BertLMPredictionHead(config)
|
320 |
+
|
321 |
+
def forward(self, sequence_output):
|
322 |
+
prediction_scores = self.predictions(sequence_output)
|
323 |
+
return prediction_scores
|
324 |
+
|
325 |
+
|
326 |
+
class BertOnlyNSPHead(nn.Module):
|
327 |
+
def __init__(self, config):
|
328 |
+
super(BertOnlyNSPHead, self).__init__()
|
329 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
330 |
+
|
331 |
+
def forward(self, pooled_output):
|
332 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
333 |
+
return seq_relationship_score
|
334 |
+
|
335 |
+
|
336 |
+
class BertPreTrainingHeads(nn.Module):
|
337 |
+
def __init__(self, config):
|
338 |
+
super(BertPreTrainingHeads, self).__init__()
|
339 |
+
self.predictions = BertLMPredictionHead(config)
|
340 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
341 |
+
|
342 |
+
def forward(self, sequence_output, pooled_output):
|
343 |
+
prediction_scores = self.predictions(sequence_output)
|
344 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
345 |
+
return prediction_scores, seq_relationship_score
|
346 |
+
|
347 |
+
|
348 |
+
class BertPreTrainedModel(nn.Module):
|
349 |
+
config_class = BertConfig
|
350 |
+
base_model_prefix = "bert"
|
351 |
+
|
352 |
+
def __init__(self, config):
|
353 |
+
super(BertPreTrainedModel, self).__init__()
|
354 |
+
self.config = config
|
355 |
+
|
356 |
+
def _init_weights(self, module):
|
357 |
+
""" Initialize the weights """
|
358 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
359 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
360 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
361 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
362 |
+
elif isinstance(module, BertLayerNorm):
|
363 |
+
module.bias.data.zero_()
|
364 |
+
module.weight.data.fill_(1.0)
|
365 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
366 |
+
module.bias.data.zero_()
|
367 |
+
|
368 |
+
|
369 |
+
class BertModel(BertPreTrainedModel):
|
370 |
+
r"""
|
371 |
+
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
372 |
+
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
373 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
374 |
+
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
375 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
376 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
377 |
+
layer weights are trained from the next sentence prediction (classification)
|
378 |
+
objective during Bert pretraining. This output is usually *not* a good summary
|
379 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
380 |
+
the sequence of hidden-states for the whole input sequence.
|
381 |
+
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
382 |
+
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
383 |
+
of shape ``(batch_size, sequence_length, hidden_size)``:
|
384 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
385 |
+
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
386 |
+
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
387 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
388 |
+
|
389 |
+
Examples::
|
390 |
+
|
391 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
392 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
393 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
394 |
+
outputs = model(input_ids)
|
395 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
396 |
+
|
397 |
+
"""
|
398 |
+
def __init__(self, config):
|
399 |
+
super(BertModel, self).__init__(config)
|
400 |
+
|
401 |
+
self.embeddings = BertEmbeddings(config)
|
402 |
+
self.encoder = BertEncoder(config)
|
403 |
+
# self.pooler = BertPooler(config)
|
404 |
+
|
405 |
+
self.apply(self._init_weights)
|
406 |
+
|
407 |
+
@torch.jit.ignore
|
408 |
+
def set_grad_checkpointing(self, enable=True):
|
409 |
+
if enable:
|
410 |
+
assert not self.config.output_attentions, \
|
411 |
+
"Grad checkpointing is currently conflict with output_attentions for BertEncoder, \
|
412 |
+
please set it to False in BertConfig"
|
413 |
+
self.encoder.grad_checkpointing = enable
|
414 |
+
|
415 |
+
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
416 |
+
if attention_mask is None:
|
417 |
+
attention_mask = torch.ones_like(input_ids)
|
418 |
+
if token_type_ids is None:
|
419 |
+
token_type_ids = torch.zeros_like(input_ids)
|
420 |
+
|
421 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
422 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
423 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
424 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
425 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
426 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
427 |
+
|
428 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
429 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
430 |
+
# positions we want to attend and -10000.0 for masked positions.
|
431 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
432 |
+
# effectively the same as removing these entirely.
|
433 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
434 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
435 |
+
|
436 |
+
# Prepare head mask if needed
|
437 |
+
# 1.0 in head_mask indicate we keep the head
|
438 |
+
# attention_probs has shape bsz x n_heads x N x N
|
439 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
440 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
441 |
+
if head_mask is not None:
|
442 |
+
if head_mask.dim() == 1:
|
443 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
444 |
+
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
445 |
+
elif head_mask.dim() == 2:
|
446 |
+
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
447 |
+
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
|
448 |
+
else:
|
449 |
+
head_mask = [None] * self.config.num_hidden_layers
|
450 |
+
|
451 |
+
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
452 |
+
encoder_outputs = self.encoder(embedding_output,
|
453 |
+
extended_attention_mask,
|
454 |
+
head_mask=head_mask)
|
455 |
+
sequence_output = encoder_outputs[0]
|
456 |
+
# pooled_output = self.pooler(sequence_output)
|
457 |
+
pooled_output = None
|
458 |
+
|
459 |
+
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
460 |
+
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
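A minimal sketch of how CLIP.encode_text in model.py drives this BertModel: token ids go in together with a padding-aware attention mask, and only the hidden state at the first ([CLS]) position is kept for projection. The ids below are arbitrary, and [PAD] is assumed to be id 0 as in the bundled Chinese vocabulary.

import torch
from cn_clip.clip.configuration_bert import BertConfig
from cn_clip.clip.modeling_bert import BertModel

config = BertConfig(vocab_size_or_config_json_file=21128, num_hidden_layers=3)
bert = BertModel(config)

input_ids = torch.tensor([[101, 2769, 4263, 102, 0, 0]])  # one sequence padded to length 6
attn_mask = input_ids.ne(0).float()                       # 1.0 for real tokens, 0.0 for [PAD]

sequence_output = bert(input_ids, attention_mask=attn_mask)[0]  # [1, 6, 768]
cls_state = sequence_output[:, 0, :]                            # what encode_text projects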
Model/CLIP/cn_clip/clip/utils.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Code modified from https://github.com/openai/CLIP

import json
import os
from pathlib import Path
from typing import Union, List
import urllib

import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from tqdm import tqdm

from cn_clip.clip import _tokenizer
from cn_clip.clip.model import convert_weights, CLIP, restore_model

__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]

_MODELS = {
    "ViT-B-16": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt",
    "ViT-L-14": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14.pt",
    "ViT-L-14-336": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14-336.pt",
    "ViT-H-14": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-h-14.pt",
    "RN50": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_rn50.pt",
}
_MODEL_INFO = {
    "ViT-B-16": {
        "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224
    },
    "ViT-L-14": {
        "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224
    },
    "ViT-L-14-336": {
        "struct": "ViT-L-14-336@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 336
    },
    "ViT-H-14": {
        "struct": "ViT-H-14@RoBERTa-wwm-ext-large-chinese",
        "input_resolution": 224
    },
    "RN50": {
        "struct": "RN50@RBT3-chinese",
        "input_resolution": 224
    },
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, resume: str = None):
    if resume is not None:
        model_path = resume
    elif name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # loading saved checkpoint
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(_MODEL_INFO[name]['struct'], checkpoint)
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)

    return model, image_transform(_MODEL_INFO[name]['input_resolution'])


def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
         bert_path=None):
    """Load CLIP and BERT model weights"""

    bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
    clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None

    restore_model(model, clip_state_dict, bert_state_dict).to(device)

    if str(device) == "cpu":
        model.float()
    return model


def tokenize(texts: Union[str, List[str]], context_length: int = 64) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all baseline models use 24 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    all_tokens = []
    for text in texts:
        all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
            :context_length - 2] + [_tokenizer.vocab['[SEP]']])

    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        assert len(tokens) <= context_length
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(image_size=224):
    transform = Compose([
        _convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform


def create_model(model_name, checkpoint=None):
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    # print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    # print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
        if isinstance(model_info['vision_layers'], str):
            model_info['vision_layers'] = eval(model_info['vision_layers'])
    # print('Model info', model_info)
    model = CLIP(**model_info)
    convert_weights(model)
    if checkpoint:
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        model.load_state_dict(sd)
    return model
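Taken together, load_from_name, image_transform and tokenize give the full inference path for this module. A minimal usage sketch follows; the image path is hypothetical, and the model(image, None) / model(None, text) encoding convention is taken from cn_clip/eval/extract_features.py later in this diff rather than from anything defined in this file.

import torch
from PIL import Image
from cn_clip.clip.utils import load_from_name, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device)  # downloads the checkpoint on first use
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical local image
texts = tokenize(["一只猫", "一只狗"]).to(device)

with torch.no_grad():
    image_feat = model(image, None)   # image tower only
    text_feat = model(None, texts)    # text tower only
    image_feat /= image_feat.norm(dim=-1, keepdim=True)
    text_feat /= text_feat.norm(dim=-1, keepdim=True)
    similarity = image_feat @ text_feat.t()  # cosine similarity, higher = better match
print(similarity)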
Model/CLIP/cn_clip/clip/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Model/CLIP/cn_clip/eval/__init__.py
ADDED
File without changes
|
Model/CLIP/cn_clip/eval/data.py
ADDED
@@ -0,0 +1,167 @@
import os
import logging
import json
from dataclasses import dataclass
from pathlib import Path

from PIL import Image
import base64
from io import BytesIO

import lmdb

from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler

import torchvision.datasets as datasets

from cn_clip.clip import tokenize

def _convert_to_rgb(image):
    return image.convert('RGB')


def _preprocess_text(text):
    # adapt the text to Chinese BERT vocab
    text = text.lower().replace("“", "\"").replace("”", "\"")
    return text

class EvalTxtDataset(Dataset):
    def __init__(self, jsonl_filename, max_txt_length=24):
        assert os.path.exists(jsonl_filename), "The annotation datafile {} does not exist!".format(jsonl_filename)

        logging.debug(f'Loading jsonl data from {jsonl_filename}.')
        self.texts = []
        with open(jsonl_filename, "r") as fin:
            for line in fin:
                obj = json.loads(line.strip())
                text_id = obj['text_id']
                text = obj['text']
                self.texts.append((text_id, text))
        logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')

        self.max_txt_length = max_txt_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text_id, text = self.texts[idx]
        text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
        return text_id, text

class EvalImgDataset(Dataset):
    def __init__(self, lmdb_imgs, resolution=224):
        assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} does not exist!".format(lmdb_imgs)

        logging.debug(f'Loading image LMDB from {lmdb_imgs}.')

        self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn_imgs = self.env_imgs.begin(buffers=True)
        self.cursor_imgs = self.txn_imgs.cursor()
        self.iter_imgs = iter(self.cursor_imgs)
        self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
        logging.info("The specified LMDB directory contains {} images.".format(self.number_images))

        self.transform = self._build_transform(resolution)

    def _build_transform(self, resolution):
        normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        return Compose([
            Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])

    def __len__(self):
        return self.number_images

    def __getitem__(self, idx):
        img_id, image_b64 = next(self.iter_imgs)
        if img_id == b"num_images":
            img_id, image_b64 = next(self.iter_imgs)

        img_id = img_id.tobytes()
        image_b64 = image_b64.tobytes()

        img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
        image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64)))  # already resized
        image = self.transform(image)

        return img_id, image

@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler

def get_eval_txt_dataset(args, max_txt_length=24):
    input_filename = args.text_data
    dataset = EvalTxtDataset(
        input_filename,
        max_txt_length=max_txt_length)
    num_samples = len(dataset)
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=args.text_batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)

def fetch_resolution(vision_model):
    # fetch the resolution from the vision model config
    vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
    with open(vision_model_config_file, 'r') as fv:
        model_info = json.load(fv)
    return model_info["image_resolution"]

def get_eval_img_dataset(args):
    lmdb_imgs = args.image_data
    dataset = EvalImgDataset(
        lmdb_imgs, resolution=fetch_resolution(args.vision_model))
    num_samples = len(dataset)
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=args.img_batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)

def get_imagenet_dataset(args, preprocess_fn, split):
    assert split in ["val"]

    data_path = args.imagenet_val
    assert data_path

    dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.img_batch_size,
        num_workers=args.num_workers,
        sampler=None,
    )

    return DataInfo(dataloader, None)
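The eval datasets are constructed from a plain args object, so they can be exercised without the full training harness. A minimal sketch, assuming a small queries jsonl exists at the hypothetical path below (one {"text_id": ..., "text": ...} object per line):

from argparse import Namespace
from cn_clip.eval.data import get_eval_txt_dataset

# Only the attributes read by get_eval_txt_dataset need to be present.
args = Namespace(text_data="valid_texts.jsonl", text_batch_size=32)  # hypothetical path

data_info = get_eval_txt_dataset(args, max_txt_length=24)
for text_ids, token_batch in data_info.dataloader:
    print(text_ids[:3], token_batch.shape)  # token_batch: [batch_size, 24] LongTensor
    break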
Model/CLIP/cn_clip/eval/evaluation.py
ADDED
@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
'''
This script computes the recall scores given the ground-truth annotations and predictions.
'''

import json
import sys
import os
import string
import numpy as np
import time

NUM_K = 10

def read_submission(submit_path, reference, k=5):
    # check whether the path of the submitted file exists
    if not os.path.exists(submit_path):
        raise Exception("The submission file is not found!")

    submission_dict = {}
    ref_qids = set(reference.keys())

    with open(submit_path) as fin:
        for line in fin:
            line = line.strip()
            try:
                pred_obj = json.loads(line)
            except:
                raise Exception('Cannot parse this line into json object: {}'.format(line))
            if "text_id" not in pred_obj:
                raise Exception('There exists one line not containing text_id: {}'.format(line))
            if not isinstance(pred_obj['text_id'], int):
                raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['text_id']))
            qid = pred_obj["text_id"]
            if "image_ids" not in pred_obj:
                raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
            image_ids = pred_obj["image_ids"]
            if not isinstance(image_ids, list):
                raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
            # check whether there are K predicted products for each text
            if len(image_ids) != k:
                raise Exception('Text_id {} has wrong number of predicted image_ids! Require {}, but {} found.'.format(qid, k, len(image_ids)))
            # check whether there exists an invalid prediction for any text
            for rank, image_id in enumerate(image_ids):
                if not isinstance(image_id, int):
                    raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
            # check whether there are duplicate predicted products for a single text
            if len(set(image_ids)) != k:
                raise Exception('Text_id {} has duplicate products in your prediction. Please check again!'.format(qid))
            submission_dict[qid] = image_ids  # here we save the list of product ids

    # check if any text is missing in the submission
    pred_qids = set(submission_dict.keys())
    nopred_qids = ref_qids - pred_qids
    if len(nopred_qids) != 0:
        raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))

    return submission_dict


def dump_2_json(info, path):
    with open(path, 'w') as output_json_file:
        json.dump(info, output_json_file)


def report_error_msg(detail, showMsg, out_p):
    error_dict = dict()
    error_dict['errorDetail'] = detail
    error_dict['errorMsg'] = showMsg
    error_dict['score'] = 0
    error_dict['scoreJson'] = {}
    error_dict['success'] = False
    dump_2_json(error_dict, out_p)


def report_score(r1, r5, r10, out_p):
    result = dict()
    result['success'] = True
    mean_recall = (r1 + r5 + r10) / 3.0
    result['score'] = mean_recall * 100
    result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
    dump_2_json(result, out_p)


def read_reference(path):
    fin = open(path)
    reference = dict()
    for line in fin:
        line = line.strip()
        obj = json.loads(line)
        reference[obj['text_id']] = obj['image_ids']
    return reference

def compute_score(golden_file, predict_file):
    # read ground-truth
    reference = read_reference(golden_file)

    # read predictions
    k = 10
    predictions = read_submission(predict_file, reference, k)

    # compute score for each text
    r1_stat, r5_stat, r10_stat = 0, 0, 0
    for qid in reference.keys():
        ground_truth_ids = set(reference[qid])
        top10_pred_ids = predictions[qid]
        if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
            r1_stat += 1
        if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
            r5_stat += 1
        if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
            r10_stat += 1
    # the higher the score, the better
    r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
    mean_recall = (r1 + r5 + r10) / 3.0
    result = [mean_recall, r1, r5, r10]
    result = [score * 100 for score in result]
    return result


if __name__ == "__main__":
    # the path of the answer json file (eg. test_queries_answers.jsonl)
    standard_path = sys.argv[1]
    # the path of the prediction file (eg. example_pred.jsonl)
    submit_path = sys.argv[2]
    # the score will be dumped into this output json file
    out_path = sys.argv[3]

    print("Read standard from %s" % standard_path)
    print("Read user submit file from %s" % submit_path)

    try:
        # read ground-truth
        reference = read_reference(standard_path)

        # read predictions
        k = 10
        predictions = read_submission(submit_path, reference, k)

        # compute score for each text
        r1_stat, r5_stat, r10_stat = 0, 0, 0
        for qid in reference.keys():
            ground_truth_ids = set(reference[qid])
            top10_pred_ids = predictions[qid]
            if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
                r1_stat += 1
            if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
                r5_stat += 1
            if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
                r10_stat += 1
        # the higher the score, the better
        r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
        report_score(r1, r5, r10, out_path)
        print("The evaluation finished successfully.")
    except Exception as e:
        report_error_msg(e.args[0], e.args[0], out_path)
        print("The evaluation failed: {}".format(e.args[0]))
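Both the ground-truth file and the submission are jsonl with one object per text query, and compute_score returns [mean_recall, R@1, R@5, R@10] scaled to percentages. A small self-contained sketch with made-up ids and hypothetical filenames (it assumes this module is importable as cn_clip.eval.evaluation):

import json
from cn_clip.eval.evaluation import compute_score

# Ground truth: each text query lists its relevant image ids.
with open("golden.jsonl", "w") as f:
    f.write(json.dumps({"text_id": 1, "image_ids": [101]}) + "\n")

# Submission: exactly k=10 distinct integer image ids per text query.
with open("pred.jsonl", "w") as f:
    f.write(json.dumps({"text_id": 1, "image_ids": [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]}) + "\n")

mean_recall, r1, r5, r10 = compute_score("golden.jsonl", "pred.jsonl")
print(mean_recall, r1, r5, r10)  # 100.0 100.0 100.0 100.0 for this toy case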
Model/CLIP/cn_clip/eval/evaluation_tr.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script computes the recall scores given the ground-truth annotations and predictions.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import json
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import string
|
10 |
+
import numpy as np
|
11 |
+
import time
|
12 |
+
|
13 |
+
NUM_K = 10
|
14 |
+
|
15 |
+
def read_submission(submit_path, reference, k=5):
|
16 |
+
# check whether the path of submitted file exists
|
17 |
+
if not os.path.exists(submit_path):
|
18 |
+
raise Exception("The submission file is not found!")
|
19 |
+
|
20 |
+
submission_dict = {}
|
21 |
+
ref_image_ids = set(reference.keys())
|
22 |
+
|
23 |
+
with open(submit_path) as fin:
|
24 |
+
for line in fin:
|
25 |
+
line = line.strip()
|
26 |
+
try:
|
27 |
+
pred_obj = json.loads(line)
|
28 |
+
except:
|
29 |
+
raise Exception('Cannot parse this line into json object: {}'.format(line))
|
30 |
+
if "image_id" not in pred_obj:
|
31 |
+
raise Exception('There exists one line not containing image_id: {}'.format(line))
|
32 |
+
if not isinstance(pred_obj['image_id'], int):
|
33 |
+
raise Exception('Found an invalid image_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['image_id']))
|
34 |
+
image_id = pred_obj['image_id']
|
35 |
+
if "text_ids" not in pred_obj:
|
36 |
+
raise Exception('There exists one line not containing the predicted text_ids: {}'.format(line))
|
37 |
+
text_ids = pred_obj["text_ids"]
|
38 |
+
if not isinstance(text_ids, list):
|
39 |
+
raise Exception('The text_ids field of image_id {} is not a list, please check your schema'.format(image_id))
|
40 |
+
# check whether there are K products for each text
|
41 |
+
if len(text_ids) != k:
|
42 |
+
raise Exception('Image_id {} has wrong number of predicted text_ids! Require {}, but {} found.'.format(image_id, k, len(text_ids)))
|
43 |
+
# check whether there exist an invalid prediction for any text
|
44 |
+
for rank, text_id in enumerate(text_ids):
|
45 |
+
if not isinstance(text_id, int):
|
46 |
+
raise Exception('Image_id {} has an invalid predicted text_id {} at rank {}, it should be an integer (not string), please check your schema'.format(image_id, text_id, rank + 1))
|
47 |
+
# check whether there are duplicate predicted products for a single text
|
48 |
+
if len(set(text_ids)) != k:
|
49 |
+
raise Exception('Image_id {} has duplicate text_ids in your prediction. Please check again!'.format(image_id))
|
50 |
+
submission_dict[image_id] = text_ids # here we save the list of product ids
|
51 |
+
|
52 |
+
# check if any text is missing in the submission
|
53 |
+
pred_image_ids = set(submission_dict.keys())
|
54 |
+
nopred_image_ids = ref_image_ids - pred_image_ids
|
55 |
+
if len(nopred_image_ids) != 0:
|
56 |
+
raise Exception('The following image_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_image_ids])))
|
57 |
+
|
58 |
+
return submission_dict
|
59 |
+
|
60 |
+
|
61 |
+
def dump_2_json(info, path):
|
62 |
+
with open(path, 'w') as output_json_file:
|
63 |
+
json.dump(info, output_json_file)
|
64 |
+
|
65 |
+
|
66 |
+
def report_error_msg(detail, showMsg, out_p):
|
67 |
+
error_dict=dict()
|
68 |
+
error_dict['errorDetail']=detail
|
69 |
+
error_dict['errorMsg']=showMsg
|
70 |
+
error_dict['score']=0
|
71 |
+
error_dict['scoreJson']={}
|
72 |
+
error_dict['success']=False
|
73 |
+
dump_2_json(error_dict,out_p)
|
74 |
+
|
75 |
+
|
76 |
+
def report_score(r1, r5, r10, out_p):
|
77 |
+
result = dict()
|
78 |
+
result['success']=True
|
79 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
80 |
+
result['score'] = mean_recall * 100
|
81 |
+
result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
|
82 |
+
dump_2_json(result,out_p)
|
83 |
+
|
84 |
+
|
85 |
+
def read_reference(path):
|
86 |
+
fin = open(path)
|
87 |
+
reference = dict()
|
88 |
+
for line in fin:
|
89 |
+
line = line.strip()
|
90 |
+
obj = json.loads(line)
|
91 |
+
reference[obj['image_id']] = obj['text_ids']
|
92 |
+
return reference
|
93 |
+
|
94 |
+
def compute_score(golden_file, predict_file):
|
95 |
+
# read ground-truth
|
96 |
+
reference = read_reference(golden_file)
|
97 |
+
|
98 |
+
# read predictions
|
99 |
+
k = 10
|
100 |
+
predictions = read_submission(predict_file, reference, k)
|
101 |
+
|
102 |
+
# compute score for each text
|
103 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
104 |
+
for qid in reference.keys():
|
105 |
+
ground_truth_ids = set(reference[qid])
|
106 |
+
top10_pred_ids = predictions[qid]
|
107 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
108 |
+
r1_stat += 1
|
109 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
110 |
+
r5_stat += 1
|
111 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
112 |
+
r10_stat += 1
|
113 |
+
# the higher score, the better
|
114 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
115 |
+
mean_recall = (r1 + r5 + r10) / 3.0
|
116 |
+
result = [mean_recall, r1, r5, r10]
|
117 |
+
result = [score * 100 for score in result]
|
118 |
+
return result
|
119 |
+
|
120 |
+
|
121 |
+
if __name__=="__main__":
|
122 |
+
# the path of answer json file (eg. test_queries_answers.jsonl)
|
123 |
+
standard_path = sys.argv[1]
|
124 |
+
# the path of prediction file (eg. example_pred.jsonl)
|
125 |
+
submit_path = sys.argv[2]
|
126 |
+
# the score will be dumped into this output json file
|
127 |
+
out_path = sys.argv[3]
|
128 |
+
|
129 |
+
print("Read standard from %s" % standard_path)
|
130 |
+
print("Read user submit file from %s" % submit_path)
|
131 |
+
|
132 |
+
try:
|
133 |
+
# read ground-truth
|
134 |
+
reference = read_reference(standard_path)
|
135 |
+
|
136 |
+
# read predictions
|
137 |
+
k = 10
|
138 |
+
predictions = read_submission(submit_path, reference, k)
|
139 |
+
|
140 |
+
# compute score for each text
|
141 |
+
r1_stat, r5_stat, r10_stat = 0, 0, 0
|
142 |
+
for qid in reference.keys():
|
143 |
+
ground_truth_ids = set(reference[qid])
|
144 |
+
top10_pred_ids = predictions[qid]
|
145 |
+
if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
|
146 |
+
r1_stat += 1
|
147 |
+
if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
|
148 |
+
r5_stat += 1
|
149 |
+
if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
|
150 |
+
r10_stat += 1
|
151 |
+
# the higher score, the better
|
152 |
+
r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
|
153 |
+
report_score(r1, r5, r10, out_path)
|
154 |
+
print("The evaluation finished successfully.")
|
155 |
+
except Exception as e:
|
156 |
+
report_error_msg(e.args[0], e.args[0], out_path)
|
157 |
+
print("The evaluation failed: {}".format(e.args[0]))
|
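evaluation_tr.py is the image-to-text mirror of evaluation.py above: the reference maps each image_id to its relevant text_ids, and the submission must supply exactly k predicted text_ids per image. A sketch of the two line formats it expects (made-up ids):

import json

golden_line = {"image_id": 1001, "text_ids": [7]}                               # ground-truth annotation line
pred_line = {"image_id": 1001, "text_ids": [7, 3, 9, 12, 4, 8, 21, 30, 2, 5]}   # exactly top-10, all ints, no duplicates

print(json.dumps(golden_line))
print(json.dumps(pred_line))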
Model/CLIP/cn_clip/eval/extract_features.py
ADDED
@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
'''
This script extracts image and text features for evaluation (with a single GPU).
'''

import os
import argparse
import logging
from pathlib import Path
import json

import torch
from tqdm import tqdm

from cn_clip.clip.model import convert_weights, CLIP
from cn_clip.training.main import convert_models_to_fp32
from cn_clip.eval.data import get_eval_img_dataset, get_eval_txt_dataset

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--extract-image-feats',
        action="store_true",
        default=False,
        help="Whether to extract image features."
    )
    parser.add_argument(
        '--extract-text-feats',
        action="store_true",
        default=False,
        help="Whether to extract text features."
    )
    parser.add_argument(
        '--image-data',
        type=str,
        default="../Multimodal_Retrieval/lmdb/test/imgs",
        help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
    )
    parser.add_argument(
        '--text-data',
        type=str,
        default="../Multimodal_Retrieval/test_texts.jsonl",
        help="If --extract-text-feats is True, specify the path of the input text Jsonl file."
    )
    parser.add_argument(
        '--image-feat-output-path',
        type=str,
        default=None,
        help="If --extract-image-feats is True, specify the path of the output image features."
    )
    parser.add_argument(
        '--text-feat-output-path',
        type=str,
        default=None,
        help="If --extract-text-feats is True, specify the path of the output text features."
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--text-batch-size", type=int, default=64, help="Text batch size."
    )
    parser.add_argument(
        "--context-length", type=int, default=64, help="The maximum length of input text (including [CLS] & [SEP] tokens)."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()

    assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!"

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Initialize the model.
    vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        for k, v in json.load(ft).items():
            model_info[k] = v

    model = CLIP(**model_info)
    convert_weights(model)

    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision == "amp" or args.precision == "fp32":
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)

    # Get data.
    if args.extract_image_feats:
        print("Preparing image inference dataset.")
        img_data = get_eval_img_dataset(args)
    if args.extract_text_feats:
        print("Preparing text inference dataset.")
        text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)

    # Resume from a checkpoint.
    print("Begin to load model checkpoint from {}.".format(args.resume))
    assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
    # Map model to be loaded to specified single gpu.
    loc = "cuda:{}".format(args.gpu)
    checkpoint = torch.load(args.resume, map_location='cpu')
    start_epoch = checkpoint["epoch"]
    sd = checkpoint["state_dict"]
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    # Make inference for texts
    if args.extract_text_feats:
        print('Make inference for texts...')
        if args.text_feat_output_path is None:
            args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
        write_cnt = 0
        with open(args.text_feat_output_path, "w") as fout:
            model.eval()
            dataloader = text_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    text_ids, texts = batch
                    texts = texts.cuda(args.gpu, non_blocking=True)
                    text_features = model(None, texts)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
                        write_cnt += 1
        print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))

    # Make inference for images
    if args.extract_image_feats:
        print('Make inference for images...')
        if args.image_feat_output_path is None:
            # by default, we store the image features under the same directory as the text features
            args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
        write_cnt = 0
        with open(args.image_feat_output_path, "w") as fout:
            model.eval()
            dataloader = img_data.dataloader
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    image_ids, images = batch
                    images = images.cuda(args.gpu, non_blocking=True)
                    image_features = model(images, None)
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
                        fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
                        write_cnt += 1
        print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))

    print("Done!")
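Each output line written above is a flat JSON object holding the id and the already L2-normalized feature vector, which is what the top-k scripts below rely on. A small sketch that reads such a file back and sanity-checks the normalization (the filename is hypothetical):

import json
import numpy as np

with open("valid_imgs.img_feat.jsonl") as fin:  # hypothetical output of --extract-image-feats
    for line in fin:
        record = json.loads(line)
        feat = np.asarray(record["feature"], dtype=np.float32)
        # Features were divided by their norm before being written, so this should be ~1.0.
        print(record["image_id"], float(np.linalg.norm(feat)))
        break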
Model/CLIP/cn_clip/eval/imagenet_zeroshot_templates.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This script records the imagenet classnames and templates (both translated in Chinese)
|
4 |
+
used for zero-shot evaluation.
|
5 |
+
|
6 |
+
The original classnames and templates in English are derived from open_clip
|
7 |
+
(https://github.com/mlfoundations/open_clip/blob/main/src/training/imagenet_zeroshot_data.py)
|
8 |
+
The translated classnames and templates in Chinese are derived from wukong
|
9 |
+
(https://gitee.com/mindspore/models/tree/master/research/mm/wukong)
|
10 |
+
'''
|
11 |
+
|
12 |
+
imagenet_classnames = [ "丁鲷", "金鱼", "大白鲨", "虎鲨", "锤头鲨", "电鳐", "黄貂鱼", "公鸡", "母鸡", "鸵鸟",
|
13 |
+
"燕雀", "金翅雀", "家朱雀", "灯芯草雀", "靛蓝雀", "蓝鹀", "夜莺", "松鸦", "喜鹊", "山雀",
|
14 |
+
"河鸟", "鸢(猛禽)", "秃头鹰", "秃鹫", "大灰猫头鹰", "欧洲火蝾螈", "普通蝾螈", "水蜥", "斑点蝾螈", "蝾螈",
|
15 |
+
"牛蛙", "树蛙", "尾蛙", "红海龟", "皮革龟", "泥龟", "淡水龟", "箱龟", "带状壁虎", "普通鬣蜥",
|
16 |
+
"美国变色龙", "鞭尾蜥蜴", "飞龙科蜥蜴", "褶边蜥蜴", "鳄鱼蜥蜴", "毒蜥", "绿蜥蜴", "非洲变色龙", "科莫多蜥蜴", "非洲鳄",
|
17 |
+
"美国鳄鱼", "三角龙", "雷蛇", "环蛇", "希腊蛇", "绿蛇", "国王蛇", "袜带蛇", "水蛇", "藤蛇",
|
18 |
+
"夜蛇", "大蟒蛇", "岩石蟒蛇", "印度眼镜蛇", "绿曼巴", "海蛇", "角腹蛇", "菱纹响尾蛇", "角响尾蛇", "三叶虫",
|
19 |
+
"盲蜘蛛", "蝎子", "黑金花园蜘蛛", "谷仓蜘蛛", "花园蜘蛛", "黑寡妇蜘蛛", "狼蛛", "狼蜘蛛", "壁虱", "蜈蚣",
|
20 |
+
"黑松鸡", "松鸡", "披肩鸡", "草原鸡", "孔雀", "鹌鹑", "鹧鸪", "非洲灰鹦鹉", "金刚鹦鹉", "硫冠鹦鹉",
|
21 |
+
"短尾鹦鹉", "褐翅鸦鹃", "食蜂鸟;蜂虎", "犀鸟", "蜂鸟", "鹟䴕", "巨嘴鸟;大嘴鸟", "野鸭", "红胸秋沙鸭", "鹅",
|
22 |
+
"黑天鹅", "大象", "针鼹鼠", "鸭嘴兽", "沙袋鼠", "考拉", "袋熊", "水母", "海葵", "脑珊瑚",
|
23 |
+
"扁形虫扁虫", "线虫", "海螺", "蜗牛", "鼻涕虫", "海蛞蝓;海参", "石鳖", "鹦鹉螺", "珍宝蟹", "石蟹",
|
24 |
+
"招潮蟹", "帝王蟹", "美国龙虾", "大螯虾", "小龙虾", "寄居蟹", "等足目动物(明虾和螃蟹近亲)", "白鹳", "黑鹳", "鹭",
|
25 |
+
"火烈鸟", "小蓝鹭", "美国鹭", "麻鸦", "鹤", "秧鹤", "欧洲水鸡", "沼泽泥母鸡", "鸨", "红翻石鹬",
|
26 |
+
"红背鹬", "红脚鹬", "半蹼鹬", "蛎鹬", "鹈鹕", "国王企鹅", "信天翁", "灰鲸", "杀人鲸", "海牛",
|
27 |
+
"海狮", "吉娃娃", "日本狆犬", "马尔济斯犬", "狮子狗", "西施犬", "布莱尼姆猎犬", "巴比狗", "玩具犬", "罗得西亚长背猎狗",
|
28 |
+
"阿富汗猎犬", "巴吉度猎犬", "比格犬", "侦探犬", "蓝色快狗", "黑褐猎浣熊犬", "沃克猎犬", "英国猎狐犬", "美洲赤狗", "俄罗斯猎狼犬",
|
29 |
+
"爱尔兰猎狼犬", "意大利灰狗", "惠比特犬", "依比沙猎犬", "挪威猎犬", "奥达猎犬", "沙克犬", "苏格兰猎鹿犬", "威玛猎犬", "斯塔福德郡斗牛犬",
|
30 |
+
"美国斯塔福德郡梗", "贝德灵顿梗", "边境梗", "凯丽蓝梗", "爱尔兰梗", "诺福克梗", "诺维奇梗", "约克犬;约克夏梗犬", "刚毛猎狐梗", "莱克兰梗",
|
31 |
+
"锡利哈姆梗", "艾尔谷犬", "凯恩梗", "澳大利亚梗", "丹迪丁蒙梗", "波士顿梗", "迷你雪纳瑞犬", "巨型雪纳瑞犬", "标准雪纳瑞犬", "苏格兰梗犬",
|
32 |
+
"西藏梗", "丝毛梗", "爱尔兰软毛梗犬", "西高地白梗", "拉萨阿普索犬", "平毛寻回犬", "卷毛寻回犬", "金毛猎犬", "拉布拉多猎犬", "乞沙比克猎犬",
|
33 |
+
"德国短毛指示犬", "维兹拉犬", "英国塞特犬", "爱尔兰雪达犬", "戈登雪达犬", "布列塔尼犬猎犬", "黄毛", "英国史宾格犬", "威尔士史宾格犬", "可卡犬",
|
34 |
+
"萨塞克斯猎犬", "爱尔兰水猎犬", "哥威斯犬", "舒柏奇犬", "比利时牧羊犬", "马里努阿犬", "伯瑞犬", "凯尔皮犬", "匈牙利牧羊犬", "老英国牧羊犬",
|
35 |
+
"喜乐蒂牧羊犬", "牧羊犬", "边境牧羊犬", "法兰德斯牧牛狗", "罗特韦尔犬", "德国牧羊犬", "多伯曼犬", "鹿犬;迷你杜宾犬", "大瑞士山地犬", "伯恩山犬",
|
36 |
+
"阿策尔山犬", "恩特尔布赫山犬", "拳师狗", "斗牛獒", "藏獒", "法国斗牛犬", "大丹犬", "圣伯纳德狗", "爱斯基摩犬", "阿拉斯加雪橇犬",
|
37 |
+
"哈士奇", "达尔马提亚", "狮毛狗", "巴辛吉狗", "八哥犬", "莱昂贝格狗", "纽芬兰犬", "大白熊犬", "萨摩耶犬", "博美犬",
|
38 |
+
"松狮", "凯斯犬", "布鲁塞尔格林芬犬", "彭布洛克威尔士科基犬", "威尔士柯基犬", "玩具贵宾犬", "迷你贵宾犬", "标准贵宾犬", "墨西哥无毛犬", "灰狼",
|
39 |
+
"白狼", "红太狼", "狼", "澳洲野狗", "豺", "非洲猎犬", "鬣狗", "红狐狸", "沙狐", "北极狐狸",
|
40 |
+
"灰狐狸", "虎斑猫", "山猫", "波斯猫", "暹罗猫", "埃及猫", "美洲狮", "猞猁", "豹子", "雪豹",
|
41 |
+
"美洲虎", "狮子", "老虎", "猎豹", "棕熊", "美洲黑熊", "冰熊", "懒熊", "獴", "猫鼬",
|
42 |
+
"虎甲虫", "瓢虫", "土鳖虫", "天牛", "龟甲虫", "粪甲虫", "犀牛甲虫", "象甲", "苍蝇", "蜜蜂",
|
43 |
+
"蚂蚁", "蚱蜢", "蟋蟀", "竹节虫", "蟑螂", "螳螂", "蝉", "叶蝉", "草蜻蛉", "蜻蜓",
|
44 |
+
"豆娘", "优红蛱蝶", "小环蝴蝶", "君主蝴蝶", "菜粉蝶", "白蝴蝶", "灰蝶", "海星", "海胆", "海黄瓜;海参",
|
45 |
+
"野兔", "兔", "安哥拉兔", "仓鼠", "刺猬", "黑松鼠", "土拨鼠", "海狸", "豚鼠", "栗色马",
|
46 |
+
"斑马", "猪", "野猪", "疣猪", "河马", "牛", "水牛", "野牛", "公羊", "大角羊",
|
47 |
+
"山羊", "狷羚", "黑斑羚", "瞪羚", "阿拉伯单峰骆驼", "骆驼", "黄鼠狼", "水貂", "臭猫", "黑足鼬",
|
48 |
+
"水獭", "臭鼬", "獾", "犰狳", "树懒", "猩猩", "大猩猩", "黑猩猩", "长臂猿", "合趾猿长臂猿",
|
49 |
+
"长尾猴", "赤猴", "狒狒", "恒河猴", "白头叶猴", "疣猴", "长鼻猴", "狨(美洲产小型长尾猴)", "卷尾猴", "吼猴",
|
50 |
+
"伶猴", "蜘蛛猴", "松鼠猴", "马达加斯加环尾狐猴", "大狐猴", "印度大象", "非洲象", "小熊猫", "大熊猫", "杖鱼",
|
51 |
+
"鳗鱼", "银鲑", "三色刺蝶鱼", "海葵鱼", "鲟鱼", "雀鳝", "狮子鱼", "河豚", "算盘", "长袍",
|
52 |
+
"学位袍", "手风琴", "原声吉他", "航空母舰", "客机", "飞艇", "祭坛", "救护车", "水陆两用车", "模拟时钟",
|
53 |
+
"蜂房", "围裙", "垃圾桶", "攻击步枪", "背包", "面包店", "平衡木", "热气球", "圆珠笔", "创可贴",
|
54 |
+
"班卓琴", "栏杆", "杠铃", "理发师的椅子", "理发店", "牲口棚", "晴雨表", "圆筒", "园地小车", "棒球",
|
55 |
+
"篮球", "婴儿床", "巴松管", "游泳帽", "沐浴毛巾", "浴缸", "沙滩车", "灯塔", "烧杯", "熊皮高帽",
|
56 |
+
"啤酒瓶", "啤酒杯", "钟塔", "(小儿用的)围嘴", "串联自行车", "比基尼", "装订册", "双筒望远镜", "鸟舍", "船库",
|
57 |
+
"双人雪橇", "饰扣式领带", "阔边女帽", "书橱", "书店", "瓶盖", "弓箭", "蝴蝶结领结", "铜制牌位", "奶罩",
|
58 |
+
"防波堤", "铠甲", "扫帚", "桶", "扣环", "防弹背心", "动车", "肉铺", "出租车", "大锅",
|
59 |
+
"蜡烛", "大炮", "独木舟", "开瓶器", "开衫", "车镜", "旋转木马", "木匠的工具包", "纸箱", "车轮",
|
60 |
+
"取款机", "盒式录音带", "卡带播放器", "城堡", "双体船", "CD播放器", "大提琴", "移动电话", "铁链", "围栏",
|
61 |
+
"链甲", "电锯", "箱子", "梳妆台", "编钟", "中国橱柜", "圣诞袜", "教堂", "电影院", "切肉刀",
|
62 |
+
"悬崖屋", "斗篷", "木屐", "鸡尾酒调酒器", "咖啡杯", "咖啡壶", "螺旋结构(楼梯)", "组合锁", "电脑键盘", "糖果",
|
63 |
+
"集装箱船", "敞篷车", "瓶塞钻", "短号", "牛仔靴", "牛仔帽", "摇篮", "起重机", "头盔", "板条箱",
|
64 |
+
"小儿床", "砂锅", "槌球", "拐杖", "胸甲", "大坝", "书桌", "台式电脑", "有线电话", "尿布湿",
|
65 |
+
"数字时钟", "数字手表", "餐桌板", "抹布", "洗碗机", "盘式制动器", "码头", "狗拉雪橇", "圆顶", "门垫",
|
66 |
+
"钻井平台", "鼓", "鼓槌", "哑铃", "荷兰烤箱", "电风扇", "电吉他", "电力机车", "组合电视柜", "信封",
|
67 |
+
"浓缩咖啡机", "扑面粉", "女用长围巾", "文件", "消防船", "消防车", "火炉栏", "旗杆", "长笛", "折叠椅",
|
68 |
+
"橄榄球头盔", "叉车", "喷泉", "钢笔", "有四根帷柱的床", "运货车厢", "圆号", "煎锅", "裘皮大衣", "垃圾车",
|
69 |
+
"防毒面具", "汽油泵", "高脚杯", "卡丁车", "高尔夫球", "高尔夫球车", "狭长小船", "锣", "礼服", "钢琴",
|
70 |
+
"温室", "散热器格栅", "杂货店", "断头台", "小发夹", "头发喷雾", "半履带装甲车", "锤子", "大篮子", "手摇鼓风机",
|
71 |
+
"手提电脑", "手帕", "硬盘", "口琴", "竖琴", "收割机", "斧头", "手枪皮套", "家庭影院", "蜂窝",
|
72 |
+
"钩爪", "衬裙", "单杠", "马车", "沙漏", "iPod", "熨斗", "南瓜灯笼", "牛仔裤", "吉普车",
|
73 |
+
"T恤衫", "拼图", "人力车", "操纵杆", "和服", "护膝", "蝴蝶结", "大褂", "长柄勺", "灯罩",
|
74 |
+
"笔记本电脑", "割草机", "镜头盖", "开信刀;拆信刀", "图书馆", "救生艇", "点火器", "豪华轿车", "远洋班轮", "唇膏",
|
75 |
+
"平底便鞋", "洗剂", "扬声器", "放大镜", "锯木厂", "磁罗盘", "邮袋", "信箱", "女游泳衣", "有肩带浴衣",
|
76 |
+
"窨井盖", "沙球(一种打击乐器)", "马林巴木琴", "面膜", "火柴", "花柱", "迷宫", "量杯", "药箱", "巨石",
|
77 |
+
"麦克风", "微波炉", "军装", "奶桶", "迷你巴士", "迷你裙", "面包车;小型货车", "导弹", "连指手套", "搅拌钵",
|
78 |
+
"活动房屋(由汽车拖拉的)", "福特T型车", "调制解调器;光猫", "修道院", "显示器", "电瓶车", "砂浆", "学士", "清真寺", "蚊帐",
|
79 |
+
"摩托车", "山地自行车", "登山帐", "鼠标", "捕鼠器", "搬家货车", "动物的口套", "金属钉子", "颈托", "项链",
|
80 |
+
"乳头(瓶)", "平板电脑", "方尖碑", "双簧管", "小鹅笛;球形笛(管身椭圆形)", "里程表", "滤油器", "风琴", "示波器", "罩裙",
|
81 |
+
"牛车", "氧气面罩", "包装", "船桨", "明轮", "挂锁", "画笔", "睡衣", "宫殿", "排箫",
|
82 |
+
"纸巾", "降落伞", "双杠", "公园长椅", "停车收费表", "客车", "露台", "付费电话", "基座", "铅笔盒",
|
83 |
+
"卷笔刀", "香水(瓶)", "培养皿", "复印机", "拨弦片", "尖顶头盔", "用尖板条连成的尖桩篱栅", "皮卡", "桥墩", "存钱罐",
|
84 |
+
"药瓶", "枕头", "乒乓球", "风车", "海盗船", "水罐", "木工刨", "天文馆", "塑料袋", "板架",
|
85 |
+
"犁型铲雪机", "手压皮碗泵", "宝丽来相机", "电线杆", "警车", "雨披", "台球桌", "充气饮料瓶", "花盆", "陶工旋盘",
|
86 |
+
"电钻", "祈祷垫", "打印机", "监狱", "炮弹", "投影仪", "冰球", "沙包", "小钱袋;手袋", "羽管笔",
|
87 |
+
"被子", "赛车", "球拍", "散热器", "收音机", "射电望远镜", "雨桶", "休闲车", "卷轴", "反射式照相机",
|
88 |
+
"冰箱", "遥控器", "餐厅", "左轮手枪", "步枪", "摇椅", "电转烤肉架", "橡皮", "橄榄球", "直尺",
|
89 |
+
"跑步鞋", "保险柜", "安全别针", "盐瓶(调味用)", "凉鞋", "纱笼", "萨克斯管", "剑鞘", "秤", "校车",
|
90 |
+
"帆船", "记分牌", "屏幕", "螺丝", "螺丝刀", "安全带", "缝纫机", "盾牌", "皮鞋店", "障子",
|
91 |
+
"购物篮", "购物车", "铁锹", "浴帽", "浴帘", "滑雪板", "滑雪面罩", "睡袋", "滑尺", "滑动门",
|
92 |
+
"角子老虎机", "潜水通气管", "摩托雪橇;雪地机动车", "扫雪机", "皂液器", "足球", "袜子", "碟式太阳能", "宽边帽", "汤碗",
|
93 |
+
"空格键", "空间加热器", "航天飞机", "锅铲;做饭的铲子", "快艇", "蜘蛛网", "纺锤;手纺用的绕线杆", "跑车", "聚光灯", "舞台",
|
94 |
+
"蒸汽机车", "钢拱桥", "钢滚筒", "听诊器", "女用披肩", "石头墙", "秒表", "火炉", "过滤器", "有轨电车",
|
95 |
+
"担架", "沙发床", "佛塔", "潜艇", "套装", "日晷", "太阳镜", "太阳镜", "防晒霜", "悬索桥",
|
96 |
+
"拖把", "运动衫", "游泳裤", "秋千", "开关", "注射器;吸管", "台灯", "坦克", "录音机", "茶壶",
|
97 |
+
"泰迪", "电视", "网球;打网球的球", "茅草", "幕布", "顶针", "打谷机;脱粒机", "宝座", "瓦屋顶", "烤面包机",
|
98 |
+
"烟草店", "马桶", "火炬", "图腾柱", "拖车;牵引车", "玩具店", "拖拉机", "半挂汽车", "托盘", "风衣",
|
99 |
+
"三轮车", "三体船", "三脚架", "凯旋门", "无轨电车", "长号", "浴盆", "旋转式栅门", "打字机键盘", "伞",
|
100 |
+
"独轮车", "直立式钢琴", "吸尘器", "花瓶;装饰瓶", "拱顶", "天鹅绒", "自动售货机", "法衣;祭衣;祭服", "高架桥", "小提琴",
|
101 |
+
"排球", "松饼机", "挂钟", "钱包;钱夹", "衣柜衣橱", "军用飞机", "洗脸盆", "洗衣机", "水瓶", "水壶",
|
102 |
+
"水塔", "威士忌壶", "哨子", "假发", "纱窗", "百叶窗", "温莎领带", "葡萄酒瓶", "飞机翅膀", "炒菜锅",
|
103 |
+
"木勺子;木头勺子", "毛织品", "原木栅栏", "沉船", "双桅船", "蒙古包", "网站;网页", "漫画", "纵横字谜", "路标",
|
104 |
+
"交通信号灯", "防尘罩", "菜单", "盘子", "墨西哥鳄梨酱;墨西哥牛油果酱", "清炖肉汤", "火锅", "乳脂蛋糕;英国甜点", "冰淇淋", "冰棍;雪糕",
|
105 |
+
"法式面包", "百吉饼", "椒盐脆饼", "芝士汉堡", "热狗", "土豆泥", "结球甘蓝", "西兰花;绿菜花", "菜花;花椰菜", "西葫芦",
|
106 |
+
"金丝瓜;意面南瓜;面条瓜", "绿色小南瓜;青南瓜", "南瓜", "黄瓜", "洋蓟;球蓟", "甜椒", "刺棘蓟", "蘑菇", "绿苹果", "草莓",
|
107 |
+
"橘子", "柠檬", "无花果", "菠萝", "香蕉", "菠萝蜜", "番荔枝", "石榴", "干草", "培根蛋酱意大利面",
|
108 |
+
"巧克力酱", "生面;面团", "瑞士肉包", "披萨", "馅饼", "卷饼", "红葡萄酒", "意式浓缩咖啡", "杯子", "蛋酒",
|
109 |
+
"高山", "泡泡", "悬崖", "珊瑚礁", "间歇泉;间断喷发的温泉", "湖边", "岬角;深入海中的狭长高地", "沙洲", "沙滩", "峡谷",
|
110 |
+
"火山", "棒球运动员", "新郎", "潜水员", "油菜", "雏菊", "黄色杓兰", "玉米", "橡子", "玫瑰果",
|
111 |
+
"七叶树果实", "珊瑚菌", "木耳", "鹿花菌", "臭角菇", "地星", "多叶奇果菌", "牛肝菌", "玉米棒子", "卫生纸"]
|
112 |
+
|
113 |
+
openai_imagenet_template = [
|
114 |
+
lambda c: f'{c}的照片。',
|
115 |
+
lambda c: f'质量差的{c}的照片。',
|
116 |
+
lambda c: f'许多{c}的照片。',
|
117 |
+
lambda c: f'{c}的雕塑。',
|
118 |
+
lambda c: f'难以看到{c}的照片。',
|
119 |
+
lambda c: f'{c}的低分辨率照片。',
|
120 |
+
lambda c: f'{c}的渲染。',
|
121 |
+
lambda c: f'涂鸦{c}。',
|
122 |
+
lambda c: f'{c}的糟糕照片。',
|
123 |
+
lambda c: f'{c}的裁剪照片。',
|
124 |
+
lambda c: f'{c}的纹身。',
|
125 |
+
lambda c: f'{c}的刺绣照片。',
|
126 |
+
lambda c: f'很难看到{c}的照片。',
|
127 |
+
lambda c: f'{c}的明亮照片。',
|
128 |
+
lambda c: f'一张干净的{c}的照片。',
|
129 |
+
lambda c: f'一张包含{c}的照片。',
|
130 |
+
lambda c: f'{c}的深色照片。',
|
131 |
+
lambda c: f'{c}的手绘画。',
|
132 |
+
lambda c: f'���的{c}的照片。',
|
133 |
+
lambda c: f'不自然的{c}的照片。',
|
134 |
+
lambda c: f'一张酷的{c}的照片。',
|
135 |
+
lambda c: f'{c}的特写照片。',
|
136 |
+
lambda c: f'{c}的黑白照片。',
|
137 |
+
lambda c: f'一幅{c}的画。',
|
138 |
+
lambda c: f'一幅{c}的绘画。',
|
139 |
+
lambda c: f'一张{c}的像素照片。',
|
140 |
+
lambda c: f'{c}的雕像。',
|
141 |
+
lambda c: f'一张{c}的明亮照片。',
|
142 |
+
lambda c: f'{c}的裁剪照片。',
|
143 |
+
lambda c: f'人造的{c}的照片。',
|
144 |
+
lambda c: f'一张关于{c}的照片。',
|
145 |
+
lambda c: f'损坏的{c}的jpeg照片。',
|
146 |
+
lambda c: f'{c}的模糊照片。',
|
147 |
+
lambda c: f'{c}的相片。',
|
148 |
+
lambda c: f'一张{c}的好照片。',
|
149 |
+
lambda c: f'{c}的渲染照。',
|
150 |
+
lambda c: f'视频游戏中的{c}。',
|
151 |
+
lambda c: f'一张{c}的照片。',
|
152 |
+
lambda c: f'{c}的涂鸦。',
|
153 |
+
lambda c: f'{c}的近距离照片。',
|
154 |
+
lambda c: f'{c}的折纸。',
|
155 |
+
lambda c: f'{c}在视频游戏中。',
|
156 |
+
lambda c: f'{c}的草图。',
|
157 |
+
lambda c: f'{c}的涂鸦照。',
|
158 |
+
lambda c: f'{c}的折纸形状。',
|
159 |
+
lambda c: f'低分辨率的{c}的照片。',
|
160 |
+
lambda c: f'玩具{c}。',
|
161 |
+
lambda c: f'{c}的副本。',
|
162 |
+
lambda c: f'{c}的干净的照片。',
|
163 |
+
lambda c: f'一张大{c}的照片。',
|
164 |
+
lambda c: f'{c}的重现。',
|
165 |
+
lambda c: f'一张漂亮的{c}的照片。',
|
166 |
+
lambda c: f'一张奇怪的{c}的照片。',
|
167 |
+
lambda c: f'模糊的{c}的照片。',
|
168 |
+
lambda c: f'卡通{c}。',
|
169 |
+
lambda c: f'{c}的艺术作品。',
|
170 |
+
lambda c: f'{c}的素描。',
|
171 |
+
lambda c: f'刺绣{c}。',
|
172 |
+
lambda c: f'{c}的像素照。',
|
173 |
+
lambda c: f'{c}的拍照。',
|
174 |
+
lambda c: f'{c}的损坏的照片。',
|
175 |
+
lambda c: f'高质量的{c}的照片。',
|
176 |
+
lambda c: f'毛绒玩具{c}。',
|
177 |
+
lambda c: f'漂亮的{c}的照片。',
|
178 |
+
lambda c: f'小{c}的照片。',
|
179 |
+
lambda c: f'照片是奇怪的{c}。',
|
180 |
+
lambda c: f'漫画{c}。',
|
181 |
+
lambda c: f'{c}的艺术照。',
|
182 |
+
lambda c: f'{c}的图形。',
|
183 |
+
lambda c: f'大{c}的照片。',
|
184 |
+
lambda c: f'黑白的{c}的照片。',
|
185 |
+
lambda c: f'{c}毛绒玩具。',
|
186 |
+
lambda c: f'一张{c}的深色照片。',
|
187 |
+
lambda c: f'{c}的摄影图。',
|
188 |
+
lambda c: f'{c}的涂鸦照。',
|
189 |
+
lambda c: f'玩具形状的{c}。',
|
190 |
+
lambda c: f'拍了{c}的照片。',
|
191 |
+
lambda c: f'酷酷的{c}的照片。',
|
192 |
+
lambda c: f'照片里的小{c}。',
|
193 |
+
lambda c: f'{c}的刺青。',
|
194 |
+
]
|
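The classnames and prompt templates above are meant to be combined at evaluation time: every class gets one prompt per template, the prompts are tokenized and encoded with the text tower, and the per-class embeddings are averaged into a zero-shot classifier. A minimal sketch of the prompt-building step; the commented encode call is only indicative, following the model(None, texts) convention used elsewhere in this diff:

import torch
from cn_clip.clip.utils import tokenize
from cn_clip.eval.imagenet_zeroshot_templates import imagenet_classnames, openai_imagenet_template

classname = imagenet_classnames[0]                  # "丁鲷"
prompts = [template(classname) for template in openai_imagenet_template]
print(prompts[:3])                                  # e.g. ['丁鲷的照片。', '质量差的丁鲷的照片。', ...]

tokens = tokenize(prompts, context_length=32)       # [num_templates, 32]
# with torch.no_grad():
#     class_embedding = model(None, tokens.cuda()).mean(dim=0)  # average over templates (model from load_from_name)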
Model/CLIP/cn_clip/eval/make_topk_predictions.py
ADDED
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
'''
This script performs kNN search on inferenced image and text features (on a single GPU) and outputs a text-to-image prediction file for evaluation.
'''

import argparse
import numpy
from tqdm import tqdm
import json

import numpy as np
import torch

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--image-feats',
        type=str,
        required=True,
        help="Specify the path of image features."
    )
    parser.add_argument(
        '--text-feats',
        type=str,
        required=True,
        help="Specify the path of text features."
    )
    parser.add_argument(
        '--top-k',
        type=int,
        default=10,
        help="Specify the k value of top-k predictions."
    )
    parser.add_argument(
        '--eval-batch-size',
        type=int,
        default=32768,
        help="Specify the image-side batch size when computing the inner products, defaults to 32768"
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help="Specify the output jsonl prediction filepath."
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    print("Begin to load image features...")
    image_ids = []
    image_feats = []
    with open(args.image_feats, "r") as fin:
        for line in tqdm(fin):
            obj = json.loads(line.strip())
            image_ids.append(obj['image_id'])
            image_feats.append(obj['feature'])
    image_feats_array = np.array(image_feats, dtype=np.float32)
    print("Finished loading image features.")

    print("Begin to compute top-{} predictions for texts...".format(args.top_k))
    with open(args.output, "w") as fout:
        with open(args.text_feats, "r") as fin:
            for line in tqdm(fin):
                obj = json.loads(line.strip())
                text_id = obj['text_id']
                text_feat = obj['feature']
                score_tuples = []
                text_feat_tensor = torch.tensor([text_feat], dtype=torch.float).cuda()  # [1, feature_dim]
                idx = 0
                while idx < len(image_ids):
                    img_feats_tensor = torch.from_numpy(image_feats_array[idx : min(idx + args.eval_batch_size, len(image_ids))]).cuda()  # [batch_size, feature_dim]
                    batch_scores = text_feat_tensor @ img_feats_tensor.t()  # [1, batch_size]
                    for image_id, score in zip(image_ids[idx : min(idx + args.eval_batch_size, len(image_ids))], batch_scores.squeeze(0).tolist()):
                        score_tuples.append((image_id, score))
                    idx += args.eval_batch_size
                top_k_predictions = sorted(score_tuples, key=lambda x: x[1], reverse=True)[:args.top_k]
                fout.write("{}\n".format(json.dumps({"text_id": text_id, "image_ids": [entry[0] for entry in top_k_predictions]})))

    print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
    print("Done!")
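Because both feature sets are unit-normalized, the batched matrix product above is just cosine similarity, and the sort gives the top-k nearest images for each text. The same ranking can be written in a few lines of plain torch on dummy data (no CUDA required):

import torch

torch.manual_seed(0)
text_feat = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)       # one query text
image_feats = torch.nn.functional.normalize(torch.randn(1000, 512), dim=-1)  # gallery of 1000 images

scores = text_feat @ image_feats.t()   # [1, 1000] cosine similarities
topk = scores.squeeze(0).topk(k=10)
print(topk.indices.tolist())           # indices of the 10 best-matching images
print(topk.values.tolist())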
Model/CLIP/cn_clip/eval/make_topk_predictions_tr.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
'''
|
3 |
+
This scripts performs kNN search on inferenced image and text features (on single-GPU) and outputs image-to-text retrieval prediction file for evaluation.
|
4 |
+
'''
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import numpy
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
'--image-feats',
|
18 |
+
type=str,
|
19 |
+
required=True,
|
20 |
+
help="Specify the path of image features."
|
21 |
+
)
|
22 |
+
parser.add_argument(
|
23 |
+
'--text-feats',
|
24 |
+
type=str,
|
25 |
+
required=True,
|
26 |
+
help="Specify the path of text features."
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
'--top-k',
|
30 |
+
type=int,
|
31 |
+
default=10,
|
32 |
+
help="Specify the k value of top-k predictions."
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
'--eval-batch-size',
|
36 |
+
type=int,
|
37 |
+
default=32768,
|
38 |
+
help="Specify the image-side batch size when computing the inner products, default to 8192"
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
'--output',
|
42 |
+
type=str,
|
43 |
+
required=True,
|
44 |
+
help="Specify the output jsonl prediction filepath."
|
45 |
+
)
|
46 |
+
return parser.parse_args()
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
args = parse_args()
|
50 |
+
|
51 |
+
# Log params.
|
52 |
+
print("Params:")
|
53 |
+
for name in sorted(vars(args)):
|
54 |
+
val = getattr(args, name)
|
55 |
+
print(f" {name}: {val}")
|
56 |
+
|
57 |
+
print("Begin to load text features...")
|
58 |
+
text_ids = []
|
59 |
+
text_feats = []
|
60 |
+
with open(args.text_feats, "r") as fin:
|
61 |
+
for line in tqdm(fin):
|
62 |
+
obj = json.loads(line.strip())
|
63 |
+
text_ids.append(obj['text_id'])
|
64 |
+
text_feats.append(obj['feature'])
|
65 |
+
text_feats_array = np.array(text_feats, dtype=np.float32)
|
66 |
+
print("Finished loading text features.")
|
67 |
+
|
68 |
+
print("Begin to compute top-{} predictions for images...".format(args.top_k))
|
69 |
+
with open(args.output, "w") as fout:
|
70 |
+
with open(args.image_feats, "r") as fin:
|
71 |
+
for line in tqdm(fin):
|
72 |
+
obj = json.loads(line.strip())
|
73 |
+
image_id = obj['image_id']
|
74 |
+
image_feat = obj['feature']
|
75 |
+
score_tuples = []
|
76 |
+
image_feat_tensor = torch.tensor([image_feat], dtype=torch.float).cuda() # [1, feature_dim]
|
77 |
+
idx = 0
|
78 |
+
while idx < len(text_ids):
|
79 |
+
text_feats_tensor = torch.from_numpy(text_feats_array[idx : min(idx + args.eval_batch_size, len(text_ids))]).cuda() # [batch_size, feature_dim]
|
80 |
+
batch_scores = image_feat_tensor @ text_feats_tensor.t() # [1, batch_size]
|
81 |
+
for text_id, score in zip(text_ids[idx : min(idx + args.eval_batch_size, len(text_ids))], batch_scores.squeeze(0).tolist()):
|
82 |
+
score_tuples.append((text_id, score))
|
83 |
+
idx += args.eval_batch_size
|
84 |
+
top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
|
85 |
+
fout.write("{}\n".format(json.dumps({"image_id": image_id, "text_ids": [entry[0] for entry in top_k_predictions]})))
|
86 |
+
|
87 |
+
print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
|
88 |
+
print("Done!")
|
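
The core of this script is the batched inner-product scoring between one image feature and all text features. A self-contained sketch of the same ranking step on random CPU tensors (hypothetical sizes, no GPU needed); torch.topk yields the same ranking as the script's sorted(...) call:

import torch

# Hypothetical sizes: one query image embedding against 10 candidate text embeddings of dim 4.
image_feat = torch.randn(1, 4)
text_feats = torch.randn(10, 4)

scores = image_feat @ text_feats.t()          # [1, 10] inner products, as in the loop above
top_k = torch.topk(scores.squeeze(0), k=3)    # rank candidates, keep the 3 best
print(top_k.indices.tolist())                 # candidate indices, analogous to the emitted text_ids
print(top_k.values.tolist())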
Model/CLIP/cn_clip/eval/transform_ir_annotation_to_tr.py
ADDED
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
from tqdm import tqdm
import argparse
import json

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        type=str,
        required=True,
        help="Input path of text-to-image Jsonl annotation file."
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    t2i_record = dict()

    with open(args.input, "r") as fin:
        for line in tqdm(fin):
            obj = json.loads(line.strip())
            text_id = obj['text_id']
            image_ids = obj['image_ids']
            for image_id in image_ids:
                if image_id not in t2i_record:
                    t2i_record[image_id] = []
                t2i_record[image_id].append(text_id)

    with open(args.input.replace(".jsonl", "") + ".tr.jsonl", "w") as fout:
        for image_id, text_ids in t2i_record.items():
            out_obj = {"image_id": image_id, "text_ids": text_ids}
            fout.write("{}\n".format(json.dumps(out_obj)))

    print("Done!")
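
As a toy illustration of the inversion this script performs (fabricated ids), two text-to-image annotation lines are regrouped into per-image lists of text ids:

import json
from collections import defaultdict

# Fabricated text-to-image annotations in the input jsonl format.
t2i_lines = [
    {"text_id": 1, "image_ids": [10, 11]},
    {"text_id": 2, "image_ids": [11]},
]

# Invert to image-to-text, mirroring the script's t2i_record logic.
i2t = defaultdict(list)
for obj in t2i_lines:
    for image_id in obj["image_ids"]:
        i2t[image_id].append(obj["text_id"])

for image_id, text_ids in i2t.items():
    print(json.dumps({"image_id": image_id, "text_ids": text_ids}))
# -> {"image_id": 10, "text_ids": [1]}
#    {"image_id": 11, "text_ids": [1, 2]}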
Model/CLIP/cn_clip/eval/zeroshot_evaluation.py
ADDED
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
'''
This script performs zero-shot evaluation on ImageNet-1K (on a single GPU).
'''

import os
import argparse
from pathlib import Path
import json
from tqdm import tqdm

import torch

from cn_clip.clip.model import convert_weights, CLIP
from cn_clip.clip import tokenize
from cn_clip.training.main import convert_models_to_fp32
from cn_clip.clip.utils import image_transform
from cn_clip.eval.data import get_imagenet_dataset, _preprocess_text
from cn_clip.eval.imagenet_zeroshot_templates import imagenet_classnames, openai_imagenet_template


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        required=True,
        help="Path to imagenet val set for conducting zero-shot evaluation.",
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=32,
        help="The maximum length of input text (including [CLS] & [SEP] tokens)."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="Number of workers for ImageNet dataset."
    )
    args = parser.parse_args()

    return args


def zero_shot_classifier(model, classnames, templates, args):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [_preprocess_text(template(classname)) for template in templates]  # format with class
            texts = tokenize(texts, context_length=args.context_length).to(args.gpu)  # tokenize
            class_embeddings = model(None, texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
    return zeroshot_weights


def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]


def run(model, classifier, dataloader, args):
    with torch.no_grad():
        top1, top5, n = 0., 0., 0.
        for images, target in tqdm(dataloader):
            images = images.to(args.gpu)
            target = target.to(args.gpu)

            # predict
            image_features = model(images, None)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

        top1 = (top1 / n)
        top5 = (top5 / n)
    return top1, top5


if __name__ == "__main__":
    args = parse_args()

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Initialize the model.
    vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        for k, v in json.load(ft).items():
            model_info[k] = v

    model = CLIP(**model_info)
    convert_weights(model)

    # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
    if args.precision == "amp" or args.precision == "fp32":
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)

    # Get imagenet eval data.
    print("Preparing imagenet val dataset.")
    data = {}
    data["imagenet-val"] = get_imagenet_dataset(args, image_transform(model_info['image_resolution']), "val")

    # Resume from a checkpoint.
    print("Begin to load model checkpoint from {}.".format(args.resume))
    assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
    # Map model to be loaded to specified single gpu.
    loc = "cuda:{}".format(args.gpu)
    checkpoint = torch.load(args.resume, map_location='cpu')
    start_epoch = checkpoint["epoch"]
    sd = checkpoint["state_dict"]
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    # Compute ensembled class embeddings
    print('Building zero-shot classifier')

    model.eval()

    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)

    # Make inference and evaluation
    print('Using classifier')
    results = {}
    top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
    results['imagenet-zeroshot-val-top1'] = top1
    results['imagenet-zeroshot-val-top5'] = top5

    print('Result:')
    print(", ".join(["{}: {}".format(k, v) for k, v in results.items()]))
    print('Finished.')
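
To make the accuracy() helper above concrete, here is a tiny worked example with fabricated logits for two samples over four classes; the helper returns top-k hit counts, which run() later divides by the number of samples n:

import torch

def accuracy(output, target, topk=(1,)):
    # Same logic as the helper above: count how often the target class appears among the top-k logits.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

# Fabricated logits: sample 0 ranks class 2 first (its true class); sample 1 ranks class 0 first (true class is 3).
logits = torch.tensor([[0.1, 0.2, 0.9, 0.0],
                       [0.8, 0.1, 0.05, 0.05]])
target = torch.tensor([2, 3])

acc1, acc2 = accuracy(logits, target, topk=(1, 2))
print(acc1, acc2)  # 1.0 1.0 -> one of the two samples is a hit at top-1, and still only one within top-2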
Model/CLIP/cn_clip/preprocess/__init__.py
ADDED
File without changes