kunyi committed on
Commit
f76d30f
1 Parent(s): 330aea7

Upload 30 files

README.md CHANGED
@@ -1,3 +1,279 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ [**中文说明**](README_CN.md) | [**English**](README.md)
2
+ # Introduction
3
+ <br><br>
4
+ This project aims to provide a better Chinese CLIP model. The training data consists of publicly accessible image URLs and their associated Chinese text descriptions, totaling 400 million image-text pairs. After screening, we ultimately used 100 million of these pairs for training.
5
+ This project was developed by the QQ-ARC Joint Lab, Tencent PCG.
6
+ <br><br>
7
+
8
+ # Models and Results
9
+ <span id="model_card"></span>
10
+ ## Model Card
11
+ QA-CLIP currently provides three open-source models of different sizes; their details and download links are listed in the table below:
12
+ <table border="1" width="100%">
13
+ <tr align="center">
14
+ <th>Model</th><th>Checkpoint</th><th>Params</th><th>Vision Backbone</th><th>Vision Params</th><th>Text Backbone</th><th>Text Params</th><th>Resolution</th>
15
+ </tr>
16
+ <tr align="center">
17
+ <td>QA-CLIP<sub>RN50</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt">Download</a></td><td>77M</td><td>ResNet50</td><td>38M</td><td>RBT3</td><td>39M</td><td>224</td>
18
+ </tr>
19
+ <tr align="center">
20
+ <td>QA-CLIP<sub>ViT-B/16</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt">Download</a></td><td>188M</td><td>ViT-B/16</td><td>86M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
21
+ </tr>
22
+ <tr align="center">
23
+ <td>QA-CLIP<sub>ViT-L/14</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt">Download</a></td><td>406M</td><td>ViT-L/14</td><td>304M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
24
+ </tr>
25
+ </table>
26
+ <br>
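+ The checkpoints above can also be fetched programmatically. Below is a minimal sketch (assuming the `huggingface_hub` package is installed); the repo id and filename are taken from the download links in the table:
+ ```python
+ # Sketch: download a QA-CLIP checkpoint from the Hugging Face Hub into the local cache.
+ from huggingface_hub import hf_hub_download
+ 
+ ckpt_path = hf_hub_download(repo_id="TencentARC/QA-CLIP", filename="QA-CLIP-base.pt")
+ print(ckpt_path)  # local path to the cached QA-CLIP-base.pt file
+ ```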
27
+
28
+ ## Results
29
+ We conducted zero-shot image-text retrieval tests on the [MUGE Retrieval](https://tianchi.aliyun.com/muge), [Flickr30K-CN](https://github.com/li-xirong/cross-lingual-cap), and [COCO-CN](https://github.com/li-xirong/coco-cn) datasets. For zero-shot image classification, we tested on the ImageNet dataset. The results are shown in the tables below:
30
+
31
+ **Flickr30K-CN Zero-shot Retrieval (Official Test Set)**:
32
+ <table border="1" width="120%">
33
+ <tr align="center">
34
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
35
+ </tr>
36
+ <tr align="center">
37
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
38
+ </tr>
39
+ <tr align="center">
40
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.8</td><td>76.0</td><td>84.6</td><td>60.0</td><td>85.9</td><td>92.0</td>
41
+ </tr>
42
+ <tr align="center">
43
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.5</b></td><td><b>77.4</b></td><td><b>86.1</b></td><td><b>67.1</b></td><td><b>87.9</b></td><td><b>93.2</b></td>
44
+ </tr>
45
+ <tr align="center">
46
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.7</td><td>86.9</td><td>92.8</td><td>74.6</td><td>93.5</td><td>97.1</td>
47
+ </tr>
48
+ <tr align="center">
49
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>63.8</b></td><td><b>88.0</b></td><td><b>93.2</b></td><td><b>78.4</b></td><td><b>96.1</b></td><td><b>98.5</b></td>
50
+ </tr>
51
+ <tr align="center">
52
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>68.0</td><td>89.7</td><td>94.4</td><td>80.2</td><td>96.6</td><td>98.2</td>
53
+ </tr>
54
+ <tr align="center">
55
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td><b>69.7</b></td><td>90.1</td><td>94.8</td><td>84.8</td><td>97.7</td><td>99.1</td>
56
+ </tr>
57
+ <tr align="center">
58
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td>69.3</td><td><b>90.3</b></td><td><b>94.7</b></td><td><b>85.3</b></td><td><b>97.9</b></td><td><b>99.2</b></td>
59
+ </tr>
60
+ </table>
61
+ <br>
62
+
63
+ **MUGE Zero-shot Retrieval (Official Validation Set)**:
64
+ <table border="1" width="120%">
65
+ <tr align="center">
66
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
67
+ </tr>
68
+ <tr align="center">
69
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
70
+ </tr>
71
+ <tr align="center">
72
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>42.6</td><td>68.5</td><td>78.0</td><td>30.0</td><td>56.2</td><td>66.9</td>
73
+ </tr>
74
+ <tr align="center">
75
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>44.0</b></td><td><b>69.9</b></td><td><b>79.5</b></td><td><b>32.4</b></td><td><b>59.5</b></td><td><b>70.3</b></td>
76
+ </tr>
77
+ <tr align="center">
78
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>52.1</td><td>76.7</td><td>84.4</td><td>38.7</td><td>65.6</td><td>75.1</td>
79
+ </tr>
80
+ <tr align="center">
81
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>53.2</b></td><td><b>77.7</b></td><td><b>85.1</b></td><td><b>40.7</b></td><td><b>68.2</b></td><td><b>77.2</b></td>
82
+ </tr>
83
+ <tr align="center">
84
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>56.4</td><td>79.8</td><td>86.2</td><td>42.6</td><td>69.8</td><td>78.6</td>
85
+ </tr>
86
+ <tr align="center">
87
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td>29.6</td><td>49.9</td><td>58.8</td><td>21.4</td><td>42.0</td><td>51.9</td>
88
+ </tr>
89
+ <tr align="center">
90
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>57.4</b></td><td><b>81.0</b></td><td><b>87.7</b></td><td><b>45.5</b></td><td><b>73.0</b></td><td><b>81.4</b></td>
91
+ </tr>
92
+ </table>
93
+ <br>
94
+
95
+ **COCO-CN Zero-shot Retrieval (Official Test Set)**:
96
+ <table border="1" width="120%">
97
+ <tr align="center">
98
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
99
+ </tr>
100
+ <tr align="center">
101
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
102
+ </tr>
103
+ <tr align="center">
104
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.1</td><td>81.3</td><td>90.5</td><td>50.9</td><td>81.1</td><td>90.5</td>
105
+ </tr>
106
+ <tr align="center">
107
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.1</b></td><td><b>82.5</b></td><td><b>91.7</b></td><td><b>56.7</b></td><td><b>85.2</b></td><td><b>92.9</b></td>
108
+ </tr>
109
+ <tr align="center">
110
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.2</td><td>87.1</td><td>94.9</td><td>56.3</td><td>84.0</td><td>93.3</td>
111
+ </tr>
112
+ <tr align="center">
113
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>62.9</b></td><td><b>87.7</b></td><td><b>94.7</b></td><td><b>61.5</b></td><td><b>87.6</b></td><td><b>94.8</b></td>
114
+ </tr>
115
+ <tr align="center">
116
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>64.9</td><td>88.8</td><td>94.2</td><td>60.6</td><td>84.4</td><td>93.1</td>
117
+ </tr>
118
+ <tr align="center">
119
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td>63.5</td><td>87.6</td><td>93.5</td><td>62.6</td><td><b>88.5</b></td><td><b>95.9</b></td>
120
+ </tr>
121
+ <tr align="center">
122
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>65.7</b></td><td><b>90.2</b></td><td><b>95.0</b></td><td><b>64.5</b></td><td>88.3</td><td>95.1</td>
123
+ </tr>
124
+ </table>
125
+ <br>
126
+
127
+ **Zero-shot Image Classification on ImageNet**:
128
+ <table border="1" width="120%">
129
+ <tr align="center">
130
+ <th>Task</th><th colspan="1">ImageNet</th>
131
+ </tr>
132
+ <tr align="center">
133
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>33.5</td>
134
+ </tr>
135
+ <tr align="center">
136
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>35.5</b></td>
137
+ </tr>
138
+ <tr align="center">
139
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>48.4</td>
140
+ </tr>
141
+ <tr align="center">
142
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>49.7</b></td>
143
+ </tr>
144
+ <tr align="center">
145
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>54.7</td>
146
+ </tr>
147
+ <tr align="center">
148
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>55.8</b></td>
149
+ </tr>
150
+ </table>
151
+ <br>
152
+
153
+ <br><br>
154
+
155
+
156
+ # Getting Started
157
+ ## Installation Requirements
158
+ Environment configuration requirements:
159
+
160
+ * python >= 3.6.4
161
+ * pytorch >= 1.8.0 (with torchvision >= 0.9.0)
162
+ * CUDA Version >= 10.2
163
+
164
+ Install required packages:
165
+ ```bash
166
+ cd /yourpath/QA-CLIP-main
167
+ pip install -r requirements.txt
168
+ ```
169
+
170
+ ## Inference Code
171
+ ```bash
172
+ export PYTHONPATH=/yourpath/QA-CLIP-main
173
+ ```
174
+ Inference code example:
175
+ ```python
176
+ import torch
177
+ from PIL import Image
178
+
179
+ import clip as clip
180
+ from clip import load_from_name, available_models
181
+ print("Available models:", available_models())
182
+ # Available models: ['ViT-B-16', 'ViT-L-14', 'RN50']
183
+
184
+ device = "cuda" if torch.cuda.is_available() else "cpu"
185
+ model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
186
+ model.eval()
187
+ image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
188
+ text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
189
+
190
+ with torch.no_grad():
191
+ image_features = model.encode_image(image)
192
+ text_features = model.encode_text(text)
193
+ # Normalize the features. Please use the normalized features for downstream tasks.
194
+ image_features /= image_features.norm(dim=-1, keepdim=True)
195
+ text_features /= text_features.norm(dim=-1, keepdim=True)
196
+
197
+ logits_per_image, logits_per_text = model.get_similarity(image, text)
198
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
199
+
200
+ print("Label probs:", probs)
201
+ ```
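+ The normalized features can also be compared directly. A small follow-up sketch (reusing the variables from the example above): the cosine-similarity matrix ranks each candidate text against the image, which is essentially what `get_similarity` computes before applying the learned logit scale and softmax.
+ ```python
+ # Follow-up sketch: rank candidate texts by cosine similarity (features are already normalized).
+ with torch.no_grad():
+     sims = image_features @ text_features.T        # shape [num_images, num_texts]
+ best = sims.argmax(dim=-1)
+ print("Best-matching text index per image:", best.cpu().numpy())
+ ```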
202
+ <br><br>
203
+
204
+ ## Prediction and Evaluation
205
+
206
+ ### Download Image-text Retrieval Test Dataset
207
+ The test sets have already been preprocessed in the <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b> project. Here are the download links they provide:
208
+
209
+ MUGE dataset: [download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)
210
+
211
+ Flickr30K-CN dataset: [download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)
212
+
213
+ Additionally, obtaining the [COCO-CN](https://github.com/li-xirong/coco-cn) dataset requires applying to the original author.
214
+
215
+ ### Download ImageNet Dataset
216
+ Please download the raw ImageNet data yourself. The [Chinese labels](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label_cn.txt) and [English labels](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label.txt) are provided by the <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b> project.
217
+ ### Image-text Retrieval Evaluation
218
+ The image-text retrieval evaluation can be run as follows:
219
+ ```bash
220
+ split=test # specify whether to compute features for the valid or test set
221
+ resume=your_ckp_path
222
+ DATAPATH=your_DATAPATH
223
+ dataset_name=Flickr30k-CN
224
+ # dataset_name=MUGE
225
+
226
+ python -u eval/extract_features.py \
227
+ --extract-image-feats \
228
+ --extract-text-feats \
229
+ --image-data="${DATAPATH}/datasets/${dataset_name}/lmdb/${split}/imgs" \
230
+ --text-data="${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl" \
231
+ --img-batch-size=32 \
232
+ --text-batch-size=32 \
233
+ --context-length=52 \
234
+ --resume=${resume} \
235
+ --vision-model=ViT-B-16 \
236
+ --text-model=RoBERTa-wwm-ext-base-chinese
237
+
238
+ python -u eval/make_topk_predictions.py \
239
+ --image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
240
+ --text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
241
+ --top-k=10 \
242
+ --eval-batch-size=32768 \
243
+ --output="${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl"
244
+
245
+ python -u eval/make_topk_predictions_tr.py \
246
+ --image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
247
+ --text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
248
+ --top-k=10 \
249
+ --eval-batch-size=32768 \
250
+ --output="${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl"
251
+
252
+ python eval/evaluation.py \
253
+ ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl \
254
+ ${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl \
255
+ ${DATAPATH}/datasets/${dataset_name}/output1.json
256
+ cat ${DATAPATH}/datasets/${dataset_name}/output1.json
257
+
258
+ python eval/transform_ir_annotation_to_tr.py \
259
+ --input ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl
260
+
261
+ python eval/evaluation_tr.py \
262
+ ${DATAPATH}/datasets/${dataset_name}/${split}_texts.tr.jsonl \
263
+ ${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl \
264
+ ${DATAPATH}/datasets/${dataset_name}/output2.json
265
+ cat ${DATAPATH}/datasets/${dataset_name}/output2.json
266
+ ```
267
+
268
+ ### ImageNet Zero-shot Classification
269
+ ImageNet zero-shot classification can be run as follows:
270
+ ```bash
271
+ bash scripts/zeroshot_eval.sh 0 \
272
+ ${DATAPATH} imagenet \
273
+ ViT-B-16 RoBERTa-wwm-ext-base-chinese \
274
+ ./pretrained_weights/QA-CLIP-base.pt
275
+ ```
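+ For reference, the script wraps the standard CLIP zero-shot recipe: encode one Chinese prompt per class, then assign each image to the class whose text embedding is most similar. A minimal Python sketch of that recipe, reusing the inference API shown earlier (the label list and prompt template here are illustrative assumptions, not taken from `zeroshot_eval.sh`):
+ ```python
+ # Illustrative zero-shot classification sketch.
+ # Assumes `model`, `preprocess`, and `device` are set up as in the inference example above.
+ import torch
+ import clip
+ from PIL import Image
+ 
+ labels = ["金鱼", "大白鲨", "虎鲨"]                       # hypothetical subset of label_cn.txt
+ text = clip.tokenize([f"一张{name}的照片" for name in labels]).to(device)
+ image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
+ 
+ with torch.no_grad():
+     image_features = model.encode_image(image)
+     text_features = model.encode_text(text)
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     pred = (image_features @ text_features.T).argmax(dim=-1).item()
+ print("Predicted label:", labels[pred])
+ ```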
276
+ # Acknowledgments
277
+ <br><br>
278
+ The project code is based on the implementation of <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>, and we are very grateful for their outstanding open-source contributions.
279
+ <br><br>
README_CN.md ADDED
@@ -0,0 +1,281 @@
1
+ [**中文说明**](README_CN.md) | [**English**](README.md)
2
+ # 项目介绍
3
+ <br><br>
4
+ 本项目旨在提供更好的中文CLIP模型。该项目使用的训练数据均为公开可访问的图像URL及相关中文文本描述,总量达到400M。经过筛选后,我们最终使用了100M的数据进行训练。
5
+ 本项目由QQ-ARC Joint Lab, Tencent PCG完成。
6
+ <br><br>
7
+
8
+ # 模型及实验
9
+ <span id="model_card"></span>
10
+ ## 模型规模 & 下载链接
11
+ QA-CLIP目前开源3个不同规模的模型,其模型信息和下载方式见下表:
12
+
13
+ <table border="1" width="100%">
14
+ <tr align="center">
15
+ <th>模型规模</th><th>下载链接</th><th>参数量</th><th>视觉侧骨架</th><th>视觉侧参数量</th><th>文本侧骨架</th><th>文本侧参数量</th><th>分辨率</th>
16
+ </tr>
17
+ <tr align="center">
18
+ <td>QA-CLIP<sub>RN50</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt">Download</a></td><td>77M</td><td>ResNet50</td><td>38M</td><td>RBT3</td><td>39M</td><td>224</td>
19
+ </tr>
20
+ <tr align="center">
21
+ <td>QA-CLIP<sub>ViT-B/16</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt">Download</a></td><td>188M</td><td>ViT-B/16</td><td>86M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
22
+ </tr>
23
+ <tr align="center">
24
+ <td>QA-CLIP<sub>ViT-L/14</sub></td><td><a href="https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt">Download</a></td><td>406M</td><td>ViT-L/14</td><td>304M</td><td>RoBERTa-wwm-Base</td><td>102M</td><td>224</td>
25
+ </tr>
26
+ </table>
27
+ <br>
28
+
29
+ ## 实验结果
30
+ 针对图文检索任务,我们在[MUGE Retrieval](https://tianchi.aliyun.com/muge)、[Flickr30K-CN](https://github.com/li-xirong/cross-lingual-cap)和[COCO-CN](https://github.com/li-xirong/coco-cn)上进行了zero-shot测试。
31
+ 针对图像零样本分类任务,我们在ImageNet数据集上进行了测试。测试结果见下表:
32
+
33
+
34
+ **Flickr30K-CN Zero-shot Retrieval (Official Test Set)**:
35
+ <table border="1" width="120%">
36
+ <tr align="center">
37
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
38
+ </tr>
39
+ <tr align="center">
40
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
41
+ </tr>
42
+ <tr align="center">
43
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.8</td><td>76.0</td><td>84.6</td><td>60.0</td><td>85.9</td><td>92.0</td>
44
+ </tr>
45
+ <tr align="center">
46
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.5</b></td><td><b>77.4</b></td><td><b>86.1</b></td><td><b>67.1</b></td><td><b>87.9</b></td><td><b>93.2</b></td>
47
+ </tr>
48
+ <tr align="center">
49
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.7</td><td>86.9</td><td>92.8</td><td>74.6</td><td>93.5</td><td>97.1</td>
50
+ </tr>
51
+ <tr align="center">
52
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>63.8</b></td><td><b>88.0</b></td><td><b>93.2</b></td><td><b>78.4</b></td><td><b>96.1</b></td><td><b>98.5</b></td>
53
+ </tr>
54
+ <tr align="center">
55
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>68.0</td><td>89.7</td><td>94.4</td><td>80.2</td><td>96.6</td><td>98.2</td>
56
+ </tr>
57
+ <tr align="center">
58
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td><b>69.7</b></td><td>90.1</td><td>94.8</td><td>84.8</td><td>97.7</td><td>99.1</td>
59
+ </tr>
60
+ <tr align="center">
61
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td>69.3</td><td><b>90.3</b></td><td><b>94.7</b></td><td><b>85.3</b></td><td><b>97.9</b></td><td><b>99.2</b></td>
62
+ </tr>
63
+ </table>
64
+ <br>
65
+
66
+ **MUGE Zero-shot Retrieval (Official Validation Set)**:
67
+ <table border="1" width="120%">
68
+ <tr align="center">
69
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
70
+ </tr>
71
+ <tr align="center">
72
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
73
+ </tr>
74
+ <tr align="center">
75
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>42.6</td><td>68.5</td><td>78.0</td><td>30.0</td><td>56.2</td><td>66.9</td>
76
+ </tr>
77
+ <tr align="center">
78
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>44.0</b></td><td><b>69.9</b></td><td><b>79.5</b></td><td><b>32.4</b></td><td><b>59.5</b></td><td><b>70.3</b></td>
79
+ </tr>
80
+ <tr align="center">
81
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>52.1</td><td>76.7</td><td>84.4</td><td>38.7</td><td>65.6</td><td>75.1</td>
82
+ </tr>
83
+ <tr align="center">
84
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>53.2</b></td><td><b>77.7</b></td><td><b>85.1</b></td><td><b>40.7</b></td><td><b>68.2</b></td><td><b>77.2</b></td>
85
+ </tr>
86
+ <tr align="center">
87
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>56.4</td><td>79.8</td><td>86.2</td><td>42.6</td><td>69.8</td><td>78.6</td>
88
+ </tr>
89
+ <tr align="center">
90
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td>29.6</td><td>49.9</td><td>58.8</td><td>21.4</td><td>42.0</td><td>51.9</td>
91
+ </tr>
92
+ <tr align="center">
93
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>57.4</b></td><td><b>81.0</b></td><td><b>87.7</b></td><td><b>45.5</b></td><td><b>73.0</b></td><td><b>81.4</b></td>
94
+ </tr>
95
+ </table>
96
+ <br>
97
+
98
+ **COCO-CN Zero-shot Retrieval (Official Test Set)**:
99
+ <table border="1" width="120%">
100
+ <tr align="center">
101
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
102
+ </tr>
103
+ <tr align="center">
104
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
105
+ </tr>
106
+ <tr align="center">
107
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>48.1</td><td>81.3</td><td>90.5</td><td>50.9</td><td>81.1</td><td>90.5</td>
108
+ </tr>
109
+ <tr align="center">
110
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>50.1</b></td><td><b>82.5</b></td><td><b>91.7</b></td><td><b>56.7</b></td><td><b>85.2</b></td><td><b>92.9</b></td>
111
+ </tr>
112
+ <tr align="center">
113
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>62.2</td><td>87.1</td><td>94.9</td><td>56.3</td><td>84.0</td><td>93.3</td>
114
+ </tr>
115
+ <tr align="center">
116
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>62.9</b></td><td><b>87.7</b></td><td><b>94.7</b></td><td><b>61.5</b></td><td><b>87.6</b></td><td><b>94.8</b></td>
117
+ </tr>
118
+ <tr align="center">
119
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>64.9</td><td>88.8</td><td>94.2</td><td>60.6</td><td>84.4</td><td>93.1</td>
120
+ </tr>
121
+ <tr align="center">
122
+ <td width="120%">AltClip<sub>ViT-L/14</sub></td><td>63.5</td><td>87.6</td><td>93.5</td><td>62.6</td><td><b>88.5</b></td><td><b>95.9</b></td>
123
+ </tr>
124
+ <tr align="center">
125
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>65.7</b></td><td><b>90.2</b></td><td><b>95.0</b></td><td><b>64.5</b></td><td>88.3</td><td>95.1</td>
126
+ </tr>
127
+ </table>
128
+ <br>
129
+
130
+ **Zero-shot Image Classification on ImageNet**:
131
+ <table border="1" width="120%">
132
+ <tr align="center">
133
+ <th>Task</th><th colspan="1">ImageNet</th>
134
+ </tr>
135
+ <tr align="center">
136
+ <td width="120%">CN-CLIP<sub>RN50</sub></td><td>33.5</td>
137
+ </tr>
138
+ <tr align="center">
139
+ <td width="120%">QA-CLIP<sub>RN50</sub></td><td><b>35.5</b></td>
140
+ </tr>
141
+ <tr align="center">
142
+ <td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>48.4</td>
143
+ </tr>
144
+ <tr align="center">
145
+ <td width="120%">QA-CLIP<sub>ViT-B/16</sub></td><td><b>49.7</b></td>
146
+ </tr>
147
+ <tr align="center">
148
+ <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>54.7</td>
149
+ </tr>
150
+ <tr align="center">
151
+ <td width="120%">QA-CLIP<sub>ViT-L/14</sub></td><td><b>55.8</b></td>
152
+ </tr>
153
+ </table>
154
+ <br>
155
+
156
+ <br><br>
157
+
158
+
159
+ # 使用教程
160
+ ## 安装要求
161
+ 环境配置要求:
162
+
163
+ * python >= 3.6.4
164
+ * pytorch >= 1.8.0 (with torchvision >= 0.9.0)
165
+ * CUDA Version >= 10.2
166
+
167
+ 安装本项目所需库
168
+ ```bash
169
+ cd /yourpath/QA-CLIP-main
170
+ pip install -r requirements.txt
171
+ ```
172
+
173
+ ## 推理代码
174
+ ```bash
175
+ export PYTHONPATH=/yourpath/QA-CLIP-main
176
+ ```
177
+ 推理代码示例:
178
+ ```python
179
+ import torch
180
+ from PIL import Image
181
+
182
+ import clip as clip
183
+ from clip import load_from_name, available_models
184
+ print("Available models:", available_models())
185
+ # Available models: ['ViT-B-16', 'ViT-L-14', 'RN50']
186
+
187
+ device = "cuda" if torch.cuda.is_available() else "cpu"
188
+ model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
189
+ model.eval()
190
+ image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
191
+ text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
192
+
193
+ with torch.no_grad():
194
+ image_features = model.encode_image(image)
195
+ text_features = model.encode_text(text)
196
+ # 对特征进行归一化,请使用归一化后的图文特征用于下游任务
197
+ image_features /= image_features.norm(dim=-1, keepdim=True)
198
+ text_features /= text_features.norm(dim=-1, keepdim=True)
199
+
200
+ logits_per_image, logits_per_text = model.get_similarity(image, text)
201
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
202
+
203
+ print("Label probs:", probs)
204
+ ```
205
+ <br><br>
206
+
207
+ ## 预测及评估
208
+
209
+ ### 图文检索测试数据集下载
210
+ <b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>项目中已经预处理好测试集,这是他们提供的下载链接:
211
+
212
+ MUGE数据:[下载链接](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)
213
+
214
+ Flickr30K-CN数据:[下载链接](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)
215
+
216
+ 另外[COCO-CN](https://github.com/li-xirong/coco-cn)数据的获取需要向原作者进行申请
217
+ ### ImageNet数据集下载
218
+ 原始数据请自行下载,[中文标签](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label_cn.txt)和[英文标签](http://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/ImageNet-1K/label.txt)同样由<b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>项目提供
219
+ ### 图文检索评估
220
+ 图文检索评估代码可以参考如下:
221
+ ```bash
222
+ split=test # 指定计算valid或test集特征
223
+ resume=your_ckp_path
224
+ DATAPATH=your_DATAPATH
225
+ dataset_name=Flickr30k-CN
226
+ # dataset_name=MUGE
227
+
228
+ python -u eval/extract_features.py \
229
+ --extract-image-feats \
230
+ --extract-text-feats \
231
+ --image-data="${DATAPATH}/datasets/${dataset_name}/lmdb/${split}/imgs" \
232
+ --text-data="${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl" \
233
+ --img-batch-size=32 \
234
+ --text-batch-size=32 \
235
+ --context-length=52 \
236
+ --resume=${resume} \
237
+ --vision-model=ViT-B-16 \
238
+ --text-model=RoBERTa-wwm-ext-base-chinese
239
+
240
+ python -u eval/make_topk_predictions.py \
241
+ --image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
242
+ --text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
243
+ --top-k=10 \
244
+ --eval-batch-size=32768 \
245
+ --output="${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl"
246
+
247
+ python -u eval/make_topk_predictions_tr.py \
248
+ --image-feats="${DATAPATH}/datasets/${dataset_name}/${split}_imgs.img_feat.jsonl" \
249
+ --text-feats="${DATAPATH}/datasets/${dataset_name}/${split}_texts.txt_feat.jsonl" \
250
+ --top-k=10 \
251
+ --eval-batch-size=32768 \
252
+ --output="${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl"
253
+
254
+ python eval/evaluation.py \
255
+ ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl \
256
+ ${DATAPATH}/datasets/${dataset_name}/${split}_predictions.jsonl \
257
+ ${DATAPATH}/datasets/${dataset_name}/output1.json
258
+ cat ${DATAPATH}/datasets/${dataset_name}/output1.json
259
+
260
+ python eval/transform_ir_annotation_to_tr.py \
261
+ --input ${DATAPATH}/datasets/${dataset_name}/${split}_texts.jsonl
262
+
263
+ python eval/evaluation_tr.py \
264
+ ${DATAPATH}/datasets/${dataset_name}/${split}_texts.tr.jsonl \
265
+ ${DATAPATH}/datasets/${dataset_name}/${split}_tr_predictions.jsonl \
266
+ ${DATAPATH}/datasets/${dataset_name}/output2.json
267
+ cat ${DATAPATH}/datasets/${dataset_name}/output2.json
268
+ ```
269
+
270
+ ### ImageNet零样本分类
271
+ ImageNet零样本分类的代码参考如下
272
+ ```bash
273
+ bash scripts/zeroshot_eval.sh 0 \
274
+ ${DATAPATH} imagenet \
275
+ ViT-B-16 RoBERTa-wwm-ext-base-chinese \
276
+ ./pretrained_weights/QA-CLIP-base.pt
277
+ ```
278
+ # 致谢
279
+ <br><br>
280
+ 项目代码基于<b>[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)</b>实现,非常感谢他们优秀的开源工作。
281
+ <br><br>
clip/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .bert_tokenizer import FullTokenizer
2
+
3
+ _tokenizer = FullTokenizer()
4
+ from .model import convert_state_dict
5
+ from .utils import load_from_name, available_models, tokenize, image_transform, load
clip/bert_tokenizer.py ADDED
@@ -0,0 +1,436 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Tokenization classes."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import collections
23
+ import re
24
+ import unicodedata
25
+ import six
26
+ from functools import lru_cache
27
+ import os
28
+
29
+ @lru_cache()
30
+ def default_vocab():
31
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt")
32
+
33
+ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
34
+ """Checks whether the casing config is consistent with the checkpoint name."""
35
+
36
+ # The casing has to be passed in by the user and there is no explicit check
37
+ # as to whether it matches the checkpoint. The casing information probably
38
+ # should have been stored in the bert_config.json file, but it's not, so
39
+ # we have to heuristically detect it to validate.
40
+
41
+ if not init_checkpoint:
42
+ return
43
+
44
+ m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
45
+ if m is None:
46
+ return
47
+
48
+ model_name = m.group(1)
49
+
50
+ lower_models = [
51
+ "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
52
+ "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
53
+ ]
54
+
55
+ cased_models = [
56
+ "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
57
+ "multi_cased_L-12_H-768_A-12"
58
+ ]
59
+
60
+ is_bad_config = False
61
+ if model_name in lower_models and not do_lower_case:
62
+ is_bad_config = True
63
+ actual_flag = "False"
64
+ case_name = "lowercased"
65
+ opposite_flag = "True"
66
+
67
+ if model_name in cased_models and do_lower_case:
68
+ is_bad_config = True
69
+ actual_flag = "True"
70
+ case_name = "cased"
71
+ opposite_flag = "False"
72
+
73
+ if is_bad_config:
74
+ raise ValueError(
75
+ "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
76
+ "However, `%s` seems to be a %s model, so you "
77
+ "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
78
+ "how the model was pre-trained. If this error is wrong, please "
79
+ "just comment out this check." % (actual_flag, init_checkpoint,
80
+ model_name, case_name, opposite_flag))
81
+
82
+
83
+ def convert_to_unicode(text):
84
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
85
+ if six.PY3:
86
+ if isinstance(text, str):
87
+ return text
88
+ elif isinstance(text, bytes):
89
+ return text.decode("utf-8", "ignore")
90
+ else:
91
+ raise ValueError("Unsupported string type: %s" % (type(text)))
92
+ elif six.PY2:
93
+ if isinstance(text, str):
94
+ return text.decode("utf-8", "ignore")
95
+ elif isinstance(text, unicode):
96
+ return text
97
+ else:
98
+ raise ValueError("Unsupported string type: %s" % (type(text)))
99
+ else:
100
+ raise ValueError("Not running on Python2 or Python 3?")
101
+
102
+
103
+ def printable_text(text):
104
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
105
+
106
+ # These functions want `str` for both Python2 and Python3, but in one case
107
+ # it's a Unicode string and in the other it's a byte string.
108
+ if six.PY3:
109
+ if isinstance(text, str):
110
+ return text
111
+ elif isinstance(text, bytes):
112
+ return text.decode("utf-8", "ignore")
113
+ else:
114
+ raise ValueError("Unsupported string type: %s" % (type(text)))
115
+ elif six.PY2:
116
+ if isinstance(text, str):
117
+ return text
118
+ elif isinstance(text, unicode):
119
+ return text.encode("utf-8")
120
+ else:
121
+ raise ValueError("Unsupported string type: %s" % (type(text)))
122
+ else:
123
+ raise ValueError("Not running on Python2 or Python 3?")
124
+
125
+
126
+ def load_vocab(vocab_file):
127
+ """Loads a vocabulary file into a dictionary."""
128
+ vocab = collections.OrderedDict()
129
+ index = 0
130
+ with open(vocab_file, "r", encoding="utf-8") as reader:
131
+ while True:
132
+ token = convert_to_unicode(reader.readline())
133
+ if not token:
134
+ break
135
+ token = token.strip()
136
+ vocab[token] = index
137
+ index += 1
138
+ return vocab
139
+
140
+
141
+ def convert_by_vocab(vocab, items):
142
+ """Converts a sequence of [tokens|ids] using the vocab."""
143
+ output = []
144
+ for item in items:
145
+ output.append(vocab[item])
146
+ return output
147
+
148
+
149
+ def convert_tokens_to_ids(vocab, tokens):
150
+ return convert_by_vocab(vocab, tokens)
151
+
152
+
153
+ def convert_ids_to_tokens(inv_vocab, ids):
154
+ return convert_by_vocab(inv_vocab, ids)
155
+
156
+
157
+ def whitespace_tokenize(text):
158
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
159
+ text = text.strip()
160
+ if not text:
161
+ return []
162
+ tokens = text.split()
163
+ return tokens
164
+
165
+
166
+ class FullTokenizer(object):
167
+ """Runs end-to-end tokenization."""
168
+
169
+ def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
170
+ self.vocab = load_vocab(vocab_file)
171
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
172
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
173
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
174
+
175
+ def tokenize(self, text):
176
+ split_tokens = []
177
+ for token in self.basic_tokenizer.tokenize(text):
178
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
179
+ split_tokens.append(sub_token)
180
+
181
+ return split_tokens
182
+
183
+ def convert_tokens_to_ids(self, tokens):
184
+ return convert_by_vocab(self.vocab, tokens)
185
+
186
+ def convert_ids_to_tokens(self, ids):
187
+ return convert_by_vocab(self.inv_vocab, ids)
188
+
189
+ @staticmethod
190
+ def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
191
+ """ Converts a sequence of tokens (string) into a single string. """
192
+
193
+ def clean_up_tokenization(out_string):
194
+ """ Clean up a list of simple English tokenization artifacts
195
+ like spaces before punctuation and abbreviated forms.
196
+ """
197
+ out_string = (
198
+ out_string.replace(" .", ".")
199
+ .replace(" ?", "?")
200
+ .replace(" !", "!")
201
+ .replace(" ,", ",")
202
+ .replace(" ' ", "'")
203
+ .replace(" n't", "n't")
204
+ .replace(" 'm", "'m")
205
+ .replace(" 's", "'s")
206
+ .replace(" 've", "'ve")
207
+ .replace(" 're", "'re")
208
+ )
209
+ return out_string
210
+
211
+ text = ' '.join(tokens).replace(' ##', '').strip()
212
+ if clean_up_tokenization_spaces:
213
+ clean_text = clean_up_tokenization(text)
214
+ return clean_text
215
+ else:
216
+ return text
217
+
218
+ def vocab_size(self):
219
+ return len(self.vocab)
220
+
221
+
222
+ class BasicTokenizer(object):
223
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
224
+
225
+ def __init__(self, do_lower_case=True):
226
+ """Constructs a BasicTokenizer.
227
+
228
+ Args:
229
+ do_lower_case: Whether to lower case the input.
230
+ """
231
+ self.do_lower_case = do_lower_case
232
+
233
+ def tokenize(self, text):
234
+ """Tokenizes a piece of text."""
235
+ text = convert_to_unicode(text)
236
+ text = self._clean_text(text)
237
+
238
+ # This was added on November 1st, 2018 for the multilingual and Chinese
239
+ # models. This is also applied to the English models now, but it doesn't
240
+ # matter since the English models were not trained on any Chinese data
241
+ # and generally don't have any Chinese data in them (there are Chinese
242
+ # characters in the vocabulary because Wikipedia does have some Chinese
243
+ # words in the English Wikipedia.).
244
+ text = self._tokenize_chinese_chars(text)
245
+
246
+ orig_tokens = whitespace_tokenize(text)
247
+ split_tokens = []
248
+ for token in orig_tokens:
249
+ if self.do_lower_case:
250
+ token = token.lower()
251
+ token = self._run_strip_accents(token)
252
+ split_tokens.extend(self._run_split_on_punc(token))
253
+
254
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
255
+ return output_tokens
256
+
257
+ def _run_strip_accents(self, text):
258
+ """Strips accents from a piece of text."""
259
+ text = unicodedata.normalize("NFD", text)
260
+ output = []
261
+ for char in text:
262
+ cat = unicodedata.category(char)
263
+ if cat == "Mn":
264
+ continue
265
+ output.append(char)
266
+ return "".join(output)
267
+
268
+ def _run_split_on_punc(self, text):
269
+ """Splits punctuation on a piece of text."""
270
+ chars = list(text)
271
+ i = 0
272
+ start_new_word = True
273
+ output = []
274
+ while i < len(chars):
275
+ char = chars[i]
276
+ if _is_punctuation(char):
277
+ output.append([char])
278
+ start_new_word = True
279
+ else:
280
+ if start_new_word:
281
+ output.append([])
282
+ start_new_word = False
283
+ output[-1].append(char)
284
+ i += 1
285
+
286
+ return ["".join(x) for x in output]
287
+
288
+ def _tokenize_chinese_chars(self, text):
289
+ """Adds whitespace around any CJK character."""
290
+ output = []
291
+ for char in text:
292
+ cp = ord(char)
293
+ if self._is_chinese_char(cp):
294
+ output.append(" ")
295
+ output.append(char)
296
+ output.append(" ")
297
+ else:
298
+ output.append(char)
299
+ return "".join(output)
300
+
301
+ def _is_chinese_char(self, cp):
302
+ """Checks whether CP is the codepoint of a CJK character."""
303
+ # This defines a "chinese character" as anything in the CJK Unicode block:
304
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
305
+ #
306
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
307
+ # despite its name. The modern Korean Hangul alphabet is a different block,
308
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
309
+ # space-separated words, so they are not treated specially and handled
310
+ # like the all of the other languages.
311
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
312
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
313
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
314
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
315
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
316
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
317
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
318
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
319
+ return True
320
+
321
+ return False
322
+
323
+ def _clean_text(self, text):
324
+ """Performs invalid character removal and whitespace cleanup on text."""
325
+ output = []
326
+ for char in text:
327
+ cp = ord(char)
328
+ if cp == 0 or cp == 0xfffd or _is_control(char):
329
+ continue
330
+ if _is_whitespace(char):
331
+ output.append(" ")
332
+ else:
333
+ output.append(char)
334
+ return "".join(output)
335
+
336
+
337
+ class WordpieceTokenizer(object):
338
+ """Runs WordPiece tokenization."""
339
+
340
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
341
+ self.vocab = vocab
342
+ self.unk_token = unk_token
343
+ self.max_input_chars_per_word = max_input_chars_per_word
344
+
345
+ def tokenize(self, text):
346
+ """Tokenizes a piece of text into its word pieces.
347
+
348
+ This uses a greedy longest-match-first algorithm to perform tokenization
349
+ using the given vocabulary.
350
+
351
+ For example:
352
+ input = "unaffable"
353
+ output = ["un", "##aff", "##able"]
354
+
355
+ Args:
356
+ text: A single token or whitespace separated tokens. This should have
357
+ already been passed through `BasicTokenizer`.
358
+
359
+ Returns:
360
+ A list of wordpiece tokens.
361
+ """
362
+
363
+ text = convert_to_unicode(text)
364
+
365
+ output_tokens = []
366
+ for token in whitespace_tokenize(text):
367
+ chars = list(token)
368
+ if len(chars) > self.max_input_chars_per_word:
369
+ output_tokens.append(self.unk_token)
370
+ continue
371
+
372
+ is_bad = False
373
+ start = 0
374
+ sub_tokens = []
375
+ while start < len(chars):
376
+ end = len(chars)
377
+ cur_substr = None
378
+ while start < end:
379
+ substr = "".join(chars[start:end])
380
+ if start > 0:
381
+ substr = "##" + substr
382
+ if substr in self.vocab:
383
+ cur_substr = substr
384
+ break
385
+ end -= 1
386
+ if cur_substr is None:
387
+ is_bad = True
388
+ break
389
+ sub_tokens.append(cur_substr)
390
+ start = end
391
+
392
+ if is_bad:
393
+ output_tokens.append(self.unk_token)
394
+ else:
395
+ output_tokens.extend(sub_tokens)
396
+ return output_tokens
397
+
398
+
399
+ def _is_whitespace(char):
400
+ """Checks whether `chars` is a whitespace character."""
401
+ # \t, \n, and \r are technically control characters but we treat them
402
+ # as whitespace since they are generally considered as such.
403
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
404
+ return True
405
+ cat = unicodedata.category(char)
406
+ if cat == "Zs":
407
+ return True
408
+ return False
409
+
410
+
411
+ def _is_control(char):
412
+ """Checks whether `chars` is a control character."""
413
+ # These are technically control characters but we count them as whitespace
414
+ # characters.
415
+ if char == "\t" or char == "\n" or char == "\r":
416
+ return False
417
+ cat = unicodedata.category(char)
418
+ if cat in ("Cc", "Cf"):
419
+ return True
420
+ return False
421
+
422
+
423
+ def _is_punctuation(char):
424
+ """Checks whether `chars` is a punctuation character."""
425
+ cp = ord(char)
426
+ # We treat all non-letter/number ASCII as punctuation.
427
+ # Characters such as "^", "$", and "`" are not in the Unicode
428
+ # Punctuation class but we treat them as punctuation anyways, for
429
+ # consistency.
430
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
431
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
432
+ return True
433
+ cat = unicodedata.category(char)
434
+ if cat.startswith("P"):
435
+ return True
436
+ return False
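A minimal usage sketch for the tokenizer defined above (assuming the bundled `clip/vocab.txt` ships with this upload): `FullTokenizer` loads that vocabulary by default via `default_vocab()`, applies `BasicTokenizer` (which whitespace-pads CJK characters) followed by `WordpieceTokenizer`, and maps the resulting tokens to vocabulary ids.
```python
# Usage sketch: FullTokenizer chains BasicTokenizer and WordpieceTokenizer over clip/vocab.txt.
from clip.bert_tokenizer import FullTokenizer

tokenizer = FullTokenizer()                      # assumes clip/vocab.txt is present (see default_vocab)
tokens = tokenizer.tokenize("皮卡丘在草地上奔跑")    # CJK input is split per character
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)
```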
clip/configuration_bert.py ADDED
@@ -0,0 +1,86 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class BertConfig(object):
26
+ r"""
27
+ :class:`~transformers.BertConfig` is the configuration class to store the configuration of a
28
+ `BertModel`.
29
+
30
+
31
+ Arguments:
32
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
33
+ hidden_size: Size of the encoder layers and the pooler layer.
34
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
35
+ num_attention_heads: Number of attention heads for each attention layer in
36
+ the Transformer encoder.
37
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
38
+ layer in the Transformer encoder.
39
+ hidden_act: The non-linear activation function (function or string) in the
40
+ encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
41
+ hidden_dropout_prob: The dropout probability for all fully connected
42
+ layers in the embeddings, encoder, and pooler.
43
+ attention_probs_dropout_prob: The dropout ratio for the attention
44
+ probabilities.
45
+ max_position_embeddings: The maximum sequence length that this model might
46
+ ever be used with. Typically set this to something large just in case
47
+ (e.g., 512 or 1024 or 2048).
48
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
49
+ `BertModel`.
50
+ initializer_range: The stddev of the truncated_normal_initializer for
51
+ initializing all weight matrices.
52
+ layer_norm_eps: The epsilon used by LayerNorm.
53
+ """
54
+
55
+ def __init__(self,
56
+ vocab_size_or_config_json_file=30522,
57
+ hidden_size=768,
58
+ num_hidden_layers=12,
59
+ num_attention_heads=12,
60
+ intermediate_size=3072,
61
+ hidden_act="gelu",
62
+ hidden_dropout_prob=0.1,
63
+ attention_probs_dropout_prob=0.1,
64
+ max_position_embeddings=512,
65
+ type_vocab_size=2,
66
+ initializer_range=0.02,
67
+ layer_norm_eps=1e-12,
68
+ output_attentions=False,
69
+ output_hidden_states=False,
70
+ use_flash_attention=False
71
+ ):
72
+ self.vocab_size = vocab_size_or_config_json_file
73
+ self.hidden_size = hidden_size
74
+ self.num_hidden_layers = num_hidden_layers
75
+ self.num_attention_heads = num_attention_heads
76
+ self.hidden_act = hidden_act
77
+ self.intermediate_size = intermediate_size
78
+ self.hidden_dropout_prob = hidden_dropout_prob
79
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
80
+ self.max_position_embeddings = max_position_embeddings
81
+ self.type_vocab_size = type_vocab_size
82
+ self.initializer_range = initializer_range
83
+ self.layer_norm_eps = layer_norm_eps
84
+ self.output_attentions = output_attentions
85
+ self.output_hidden_states = output_hidden_states
86
+ self.use_flash_attention = use_flash_attention
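For orientation, a brief sketch of constructing this configuration object; every field has a default, so only the values that differ need to be passed. The concrete numbers below are illustrative assumptions, not taken from the repo's actual text-tower configs.
```python
# Sketch: build a BertConfig; unspecified fields fall back to the defaults defined above.
from clip.configuration_bert import BertConfig

config = BertConfig(
    vocab_size_or_config_json_file=21128,   # assumed vocabulary size for a Chinese BERT, illustrative only
    num_hidden_layers=12,
    use_flash_attention=False,
)
print(config.vocab_size, config.hidden_size, config.num_attention_heads)
```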
clip/model.py ADDED
@@ -0,0 +1,914 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ from itertools import repeat
4
+ import collections.abc
5
+
6
+ import math
7
+ import logging
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.utils.checkpoint import checkpoint
13
+
14
+ import importlib.util
15
+ if importlib.util.find_spec('flash_attn'):
16
+ FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
17
+
18
+ from clip import _tokenizer
19
+ from clip.configuration_bert import BertConfig
20
+ from clip.modeling_bert import BertModel
21
+
22
+ try:
23
+ from transformers import CLIPTextModelWithProjection
24
+ except:
25
+ pass
26
+
27
+ class Bottleneck(nn.Module):
28
+ expansion = 4
29
+
30
+ def __init__(self, inplanes, planes, stride=1):
31
+ super().__init__()
32
+
33
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
34
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
35
+ self.bn1 = nn.BatchNorm2d(planes)
36
+
37
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
38
+ self.bn2 = nn.BatchNorm2d(planes)
39
+
40
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
41
+
42
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
43
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
44
+
45
+ self.relu = nn.ReLU(inplace=True)
46
+ self.downsample = None
47
+ self.stride = stride
48
+
49
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
50
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
51
+ self.downsample = nn.Sequential(OrderedDict([
52
+ ("-1", nn.AvgPool2d(stride)),
53
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
54
+ ("1", nn.BatchNorm2d(planes * self.expansion))
55
+ ]))
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ identity = x
59
+
60
+ out = self.relu(self.bn1(self.conv1(x)))
61
+ out = self.relu(self.bn2(self.conv2(out)))
62
+ out = self.avgpool(out)
63
+ out = self.bn3(self.conv3(out))
64
+
65
+ if self.downsample is not None:
66
+ identity = self.downsample(x)
67
+
68
+ out += identity
69
+ out = self.relu(out)
70
+ return out
71
+
72
+
73
+ class AttentionPool2d(nn.Module):
74
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
75
+ super().__init__()
76
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
77
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
78
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
79
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
80
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
81
+ self.num_heads = num_heads
82
+
83
+ def forward(self, x):
84
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
85
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
86
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
87
+ x, _ = F.multi_head_attention_forward(
88
+ query=x, key=x, value=x,
89
+ embed_dim_to_check=x.shape[-1],
90
+ num_heads=self.num_heads,
91
+ q_proj_weight=self.q_proj.weight,
92
+ k_proj_weight=self.k_proj.weight,
93
+ v_proj_weight=self.v_proj.weight,
94
+ in_proj_weight=None,
95
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
96
+ bias_k=None,
97
+ bias_v=None,
98
+ add_zero_attn=False,
99
+ dropout_p=0,
100
+ out_proj_weight=self.c_proj.weight,
101
+ out_proj_bias=self.c_proj.bias,
102
+ use_separate_proj_weight=True,
103
+ training=self.training,
104
+ need_weights=False
105
+ )
106
+
107
+ return x[0]
108
+
109
+
110
+ class ModifiedResNet(nn.Module):
111
+ """
112
+ A ResNet class that is similar to torchvision's but contains the following changes:
113
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
114
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
115
+ - The final pooling layer is a QKV attention instead of an average pool
116
+ """
117
+
118
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
119
+ super().__init__()
120
+ self.output_dim = output_dim
121
+ self.input_resolution = input_resolution
122
+
123
+ # the 3-layer stem
124
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
125
+ self.bn1 = nn.BatchNorm2d(width // 2)
126
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
127
+ self.bn2 = nn.BatchNorm2d(width // 2)
128
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
129
+ self.bn3 = nn.BatchNorm2d(width)
130
+ self.avgpool = nn.AvgPool2d(2)
131
+ self.relu = nn.ReLU(inplace=True)
132
+
133
+ # residual layers
134
+ self._inplanes = width # this is a *mutable* variable used during construction
135
+ self.layer1 = self._make_layer(width, layers[0])
136
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
137
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
138
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
139
+
140
+ embed_dim = width * 32 # the ResNet feature dimension
141
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
142
+
143
+ def _make_layer(self, planes, blocks, stride=1):
144
+ layers = [Bottleneck(self._inplanes, planes, stride)]
145
+
146
+ self._inplanes = planes * Bottleneck.expansion
147
+ for _ in range(1, blocks):
148
+ layers.append(Bottleneck(self._inplanes, planes))
149
+
150
+ return nn.Sequential(*layers)
151
+
152
+ @torch.jit.ignore
153
+ def set_grad_checkpointing(self, enable=True):
154
+ # FIXME support for non-transformer
155
+ pass
156
+
157
+ def forward(self, x):
158
+ def stem(x):
159
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
160
+ x = self.relu(bn(conv(x)))
161
+ x = self.avgpool(x)
162
+ return x
163
+
164
+ x = x.type(self.conv1.weight.dtype)
165
+ x = stem(x)
166
+ x = self.layer1(x)
167
+ x = self.layer2(x)
168
+ x = self.layer3(x)
169
+ x = self.layer4(x)
170
+ x = self.attnpool(x)
171
+
172
+ return x
173
+
174
+
175
+ class LayerNorm(nn.LayerNorm):
176
+ """Subclass torch's LayerNorm to handle fp16."""
177
+
178
+ def forward(self, x: torch.Tensor):
179
+ orig_type = x.dtype
180
+ ret = super().forward(x.type(torch.float32))
181
+ return ret.type(orig_type)
182
+
183
+
184
+ class QuickGELU(nn.Module):
185
+ def forward(self, x: torch.Tensor):
186
+ return x * torch.sigmoid(1.702 * x)
187
+
188
+
189
+ class ResidualAttentionBlock(nn.Module):
190
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
191
+ super().__init__()
192
+
193
+ self.attn = nn.MultiheadAttention(d_model, n_head) if not use_flash_attention else FlashMHA(d_model, n_head)
194
+ self.ln_1 = LayerNorm(d_model)
195
+ self.mlp = nn.Sequential(OrderedDict([
196
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
197
+ ("gelu", QuickGELU()),
198
+ ("c_proj", nn.Linear(d_model * 4, d_model))
199
+ ]))
200
+ self.ln_2 = LayerNorm(d_model)
201
+ self.attn_mask = attn_mask
202
+ self.use_flash_attention = use_flash_attention
203
+
204
+ def attention(self, x: torch.Tensor):
205
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
206
+ if self.use_flash_attention:
207
+ # Batch first is needed for FlashAttention. See https://github.com/HazyResearch/flash-attention/issues/84 for more information.
208
+ return self.attn(x.transpose(1, 0))[0].transpose(1, 0)
209
+ else:
210
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
211
+
212
+ def forward(self, x: torch.Tensor):
213
+ x = x + self.attention(self.ln_1(x))
214
+ x = x + self.mlp(self.ln_2(x))
215
+ return x
216
+
217
+
218
+ class Transformer(nn.Module):
219
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
220
+ super().__init__()
221
+ self.width = width
222
+ self.layers = layers
223
+ self.grad_checkpointing = False
224
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_flash_attention) for _ in range(layers)])
225
+
226
+ def forward(self, x: torch.Tensor):
227
+ if self.grad_checkpointing and not torch.jit.is_scripting():
228
+ for r in self.resblocks:
229
+ x = checkpoint(r, x)
230
+ return x
231
+ return self.resblocks(x)
232
+
233
+
234
+ class VisualTransformer(nn.Module):
235
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, use_flash_attention: bool = False):
236
+ super().__init__()
237
+ self.input_resolution = input_resolution
238
+ self.grid_size = (self.input_resolution // patch_size, self.input_resolution // patch_size)
239
+ self.output_dim = output_dim
240
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
241
+
242
+ scale = width ** -0.5
243
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
244
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
245
+ self.ln_pre = LayerNorm(width)
246
+
247
+ self.transformer = Transformer(width, layers, heads, use_flash_attention=use_flash_attention)
248
+
249
+ self.ln_post = LayerNorm(width)
250
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
251
+
252
+ @torch.jit.ignore
253
+ def set_grad_checkpointing(self, enable=True):
254
+ self.transformer.grad_checkpointing = enable
255
+
256
+ def random_masking(self, x, mask_ratio):
257
+ N, L, D = x.shape # batch, length, dim
258
+ len_keep = int((L - 1) * (1 - mask_ratio))
259
+
260
+ noise = torch.rand(N, L - 1, device=x.device)
261
+ ids_shuffle = torch.argsort(noise, dim=1) + 1  # shift indices by 1 so the class token at position 0 is never masked
263
+ ids_keep = ids_shuffle[:, :len_keep]
264
+
265
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
266
+
267
+ x0 = x[:, 0, :]
268
+ x0 = x0.reshape(N, 1, D)
269
+ x_masked_add = torch.cat([x0, x_masked], axis=1)
270
+ return x_masked_add
271
+
272
+ def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
273
+ x = self.conv1(x) # shape = [*, width, grid, grid]
274
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
275
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
276
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
277
+ x = x + self.positional_embedding.to(x.dtype)
278
+ if mask_ratio != 0:
279
+ x = self.random_masking(x, mask_ratio)
280
+ x = self.ln_pre(x)
281
+
282
+ x = x.permute(1, 0, 2) # NLD -> LND
283
+ x = self.transformer(x)
284
+ x = x.permute(1, 0, 2) # LND -> NLD
285
+
286
+ x = self.ln_post(x[:, 0, :])
287
+
288
+ if self.proj is not None:
289
+ x = x @ self.proj
290
+
291
+ return x
292
+
293
+
294
+ class CLIP(nn.Module):
295
+ def __init__(self,
296
+ embed_dim: int,
297
+ # vision
298
+ image_resolution: int,
299
+ vision_layers: Union[Tuple[int, int, int, int], int],
300
+ vision_width: int,
301
+ vision_patch_size: int,
302
+ # text
303
+ vocab_size: int,
304
+ text_attention_probs_dropout_prob: float,
305
+ text_hidden_act: str,
306
+ text_hidden_dropout_prob: float,
307
+ text_hidden_size: int,
308
+ text_initializer_range: float,
309
+ text_intermediate_size: int,
310
+ text_max_position_embeddings: int,
311
+ text_num_attention_heads: int,
312
+ text_num_hidden_layers: int,
313
+ text_type_vocab_size: int,
314
+ tokenizer = _tokenizer,
315
+ # vision head width, added this param for ViT-H
316
+ vision_head_width: int = 64,
317
+ use_flash_attention: bool = False,
318
+ ):
319
+ super().__init__()
320
+
321
+ if isinstance(vision_layers, (tuple, list)):
322
+ vision_heads = vision_width * 32 // vision_head_width
323
+ self.visual = ModifiedResNet(
324
+ layers=vision_layers,
325
+ output_dim=embed_dim,
326
+ heads=vision_heads,
327
+ input_resolution=image_resolution,
328
+ width=vision_width
329
+ )
330
+ else:
331
+ vision_heads = vision_width // vision_head_width
332
+ self.visual = VisualTransformer(
333
+ input_resolution=image_resolution,
334
+ patch_size=vision_patch_size,
335
+ width=vision_width,
336
+ layers=vision_layers,
337
+ heads=vision_heads,
338
+ output_dim=embed_dim,
339
+ use_flash_attention=use_flash_attention
340
+ )
341
+
342
+ self.bert_config = BertConfig(
343
+ vocab_size_or_config_json_file=vocab_size,
344
+ hidden_size=text_hidden_size,
345
+ num_hidden_layers=text_num_hidden_layers,
346
+ num_attention_heads=text_num_attention_heads,
347
+ intermediate_size=text_intermediate_size,
348
+ hidden_act=text_hidden_act,
349
+ hidden_dropout_prob=text_hidden_dropout_prob,
350
+ attention_probs_dropout_prob=text_attention_probs_dropout_prob,
351
+ max_position_embeddings=text_max_position_embeddings,
352
+ type_vocab_size=text_type_vocab_size,
353
+ initializer_range=text_initializer_range,
354
+ layer_norm_eps=1e-12,
355
+ use_flash_attention=use_flash_attention
356
+ )
357
+ self.bert = BertModel(self.bert_config)
358
+
359
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
360
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
361
+
362
+ self.tokenizer = tokenizer
363
+
364
+ self.initialize_parameters()
365
+
366
+ def initialize_parameters(self):
367
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
368
+
369
+ if isinstance(self.visual, ModifiedResNet):
370
+ if self.visual.attnpool is not None:
371
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
372
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
373
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
374
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
375
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
376
+
377
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
378
+ for name, param in resnet_block.named_parameters():
379
+ if name.endswith("bn3.weight"):
380
+ nn.init.zeros_(param)
381
+
382
+ if self.text_projection is not None:
383
+ nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
384
+
385
+ @torch.jit.ignore
386
+ def set_grad_checkpointing(self, enable=True):
387
+ self.visual.set_grad_checkpointing(enable)
388
+ self.bert.set_grad_checkpointing(enable)
389
+
390
+ @property
391
+ def dtype(self):
392
+ return self.visual.conv1.weight.dtype
393
+
394
+ def encode_image(self, image, mask_ratio=0):
395
+ if isinstance(self.visual, ModifiedResNet):
396
+ # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
397
+ return self.visual(image.type(self.dtype))
398
+ return self.visual(image.type(self.dtype), mask_ratio)
399
+
400
+ def encode_text(self, text):
401
+ pad_index = self.tokenizer.vocab['[PAD]']
402
+ attn_mask = text.ne(pad_index).type(self.dtype)
403
+ x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
404
+ return x[:, 0, :] @ self.text_projection
405
+
406
+ def forward(self, image, text, mask_ratio=0):
407
+ assert image is not None or text is not None, "text and image cannot both be None!"
408
+
409
+ if image is None:
410
+ return self.encode_text(text)
411
+ elif text is None:
412
+ return self.encode_image(image, mask_ratio)
413
+ image_features = self.encode_image(image, mask_ratio)
414
+ text_features = self.encode_text(text)
415
+
416
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
417
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
418
+
419
+ return image_features, text_features, self.logit_scale.exp()
420
+
421
+ def get_similarity(self, image, text):
422
+ image_features = self.encode_image(image)
423
+ text_features = self.encode_text(text)
424
+
425
+ # normalized features
426
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
427
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
428
+
429
+ # cosine similarity as logits
430
+ logit_scale = self.logit_scale.exp()
431
+ logits_per_image = logit_scale * image_features @ text_features.t()
432
+ logits_per_text = logits_per_image.t()
433
+
434
+ # shape = [global_batch_size, global_batch_size]
435
+ return logits_per_image, logits_per_text
436
+
437
+ class CLIPWithTwoTextEncoder(nn.Module):
438
+ def __init__(self,
439
+ embed_dim: int,
440
+ # vision
441
+ image_resolution: int,
442
+ vision_layers: Union[Tuple[int, int, int, int], int],
443
+ vision_width: int,
444
+ vision_patch_size: int,
445
+ # text
446
+ vocab_size: int,
447
+ text_attention_probs_dropout_prob: float,
448
+ text_hidden_act: str,
449
+ text_hidden_dropout_prob: float,
450
+ text_hidden_size: int,
451
+ text_initializer_range: float,
452
+ text_intermediate_size: int,
453
+ text_max_position_embeddings: int,
454
+ text_num_attention_heads: int,
455
+ text_num_hidden_layers: int,
456
+ text_type_vocab_size: int,
457
+ tokenizer = _tokenizer,
458
+ # vision head width, added this param for ViT-H
459
+ vision_head_width: int = 64,
460
+ use_flash_attention: bool = False,
461
+ openai_clip_path: str = "/group/30042/kunyi/CLIP/clip-vit-large-patch14/",
462
+ ):
463
+ super().__init__()
464
+
465
+ if isinstance(vision_layers, (tuple, list)):
466
+ vision_heads = vision_width * 32 // vision_head_width
467
+ self.visual = ModifiedResNet(
468
+ layers=vision_layers,
469
+ output_dim=embed_dim,
470
+ heads=vision_heads,
471
+ input_resolution=image_resolution,
472
+ width=vision_width
473
+ )
474
+ else:
475
+ vision_heads = vision_width // vision_head_width
476
+ self.visual = VisualTransformer(
477
+ input_resolution=image_resolution,
478
+ patch_size=vision_patch_size,
479
+ width=vision_width,
480
+ layers=vision_layers,
481
+ heads=vision_heads,
482
+ output_dim=embed_dim,
483
+ use_flash_attention=use_flash_attention
484
+ )
485
+
486
+ self.bert_config = BertConfig(
487
+ vocab_size_or_config_json_file=vocab_size,
488
+ hidden_size=text_hidden_size,
489
+ num_hidden_layers=text_num_hidden_layers,
490
+ num_attention_heads=text_num_attention_heads,
491
+ intermediate_size=text_intermediate_size,
492
+ hidden_act=text_hidden_act,
493
+ hidden_dropout_prob=text_hidden_dropout_prob,
494
+ attention_probs_dropout_prob=text_attention_probs_dropout_prob,
495
+ max_position_embeddings=text_max_position_embeddings,
496
+ type_vocab_size=text_type_vocab_size,
497
+ initializer_range=text_initializer_range,
498
+ layer_norm_eps=1e-12,
499
+ use_flash_attention=use_flash_attention
500
+ )
501
+ self.bert = BertModel(self.bert_config)
502
+
503
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
504
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
505
+
506
+ self.tokenizer = tokenizer
507
+
508
+ print('loading openai clip text encoder')
509
+ self.openai_clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(openai_clip_path)
510
+
511
+ self.initialize_parameters()
512
+
513
+
514
+ def initialize_parameters(self):
515
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
516
+
517
+ if isinstance(self.visual, ModifiedResNet):
518
+ if self.visual.attnpool is not None:
519
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
520
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
521
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
522
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
523
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
524
+
525
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
526
+ for name, param in resnet_block.named_parameters():
527
+ if name.endswith("bn3.weight"):
528
+ nn.init.zeros_(param)
529
+
530
+ if self.text_projection is not None:
531
+ nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
532
+
533
+ @torch.jit.ignore
534
+ def set_grad_checkpointing(self, enable=True):
535
+ self.visual.set_grad_checkpointing(enable)
536
+ self.bert.set_grad_checkpointing(enable)
537
+
538
+ @property
539
+ def dtype(self):
540
+ return self.visual.conv1.weight.dtype
541
+
542
+ def encode_image(self, image, mask_ratio=0):
543
+ if isinstance(self.visual, ModifiedResNet):
544
+ # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
545
+ return self.visual(image.type(self.dtype))
546
+ return self.visual(image.type(self.dtype), mask_ratio)
547
+
548
+ def encode_text(self, text):
549
+ pad_index = self.tokenizer.vocab['[PAD]']
550
+ attn_mask = text.ne(pad_index).type(self.dtype)
551
+ x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
552
+ return x[:, 0, :] @ self.text_projection
553
+
554
+ def encode_text_ENG(self, text):
555
+ text_emb = self.openai_clip_text_encoder(text).text_embeds
556
+ return text_emb
557
+
558
+ def forward(self, image, text, is_ENG=False, mask_ratio=0):
559
+ assert image is not None or text is not None, "text and image cannot both be None!"
560
+
561
+ if image is None:
562
+ if not is_ENG:
563
+ return self.encode_text(text)
564
+ else:
565
+ return self.encode_text_ENG(text)
566
+ elif text is None:
567
+ return self.encode_image(image, mask_ratio)
568
+ image_features = self.encode_image(image, mask_ratio)
569
+
570
+ if not is_ENG:
571
+ text_features = self.encode_text(text)
572
+ else:
573
+ text_features = self.encode_text_ENG(text)
574
+
575
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
576
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
577
+
578
+ return image_features, text_features, self.logit_scale.exp()
579
+
580
+ def get_similarity(self, image, text):
581
+ image_features = self.encode_image(image)
582
+ text_features = self.encode_text(text)
583
+
584
+ # normalized features
585
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
586
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
587
+
588
+ # cosine similarity as logits
589
+ logit_scale = self.logit_scale.exp()
590
+ logits_per_image = logit_scale * image_features @ text_features.t()
591
+ logits_per_text = logits_per_image.t()
592
+
593
+ # shape = [global_batch_size, global_batch_size]
594
+ return logits_per_image, logits_per_text
595
+
596
+ class CLIP4SD(nn.Module):
597
+ def __init__(self,
598
+ embed_dim: int,
599
+ # vision
600
+ image_resolution: int,
601
+ vision_layers: Union[Tuple[int, int, int, int], int],
602
+ vision_width: int,
603
+ vision_patch_size: int,
604
+ # text
605
+ vocab_size: int,
606
+ text_attention_probs_dropout_prob: float,
607
+ text_hidden_act: str,
608
+ text_hidden_dropout_prob: float,
609
+ text_hidden_size: int,
610
+ text_initializer_range: float,
611
+ text_intermediate_size: int,
612
+ text_max_position_embeddings: int,
613
+ text_num_attention_heads: int,
614
+ text_num_hidden_layers: int,
615
+ text_type_vocab_size: int,
616
+ tokenizer = _tokenizer,
617
+ # vision head width, added this param for ViT-H
618
+ vision_head_width: int = 64,
619
+ use_flash_attention: bool = False,
620
+ ):
621
+ super().__init__()
622
+
623
+ if isinstance(vision_layers, (tuple, list)):
624
+ vision_heads = vision_width * 32 // vision_head_width
625
+ self.visual = ModifiedResNet(
626
+ layers=vision_layers,
627
+ output_dim=embed_dim,
628
+ heads=vision_heads,
629
+ input_resolution=image_resolution,
630
+ width=vision_width
631
+ )
632
+ else:
633
+ vision_heads = vision_width // vision_head_width
634
+ self.visual = VisualTransformer(
635
+ input_resolution=image_resolution,
636
+ patch_size=vision_patch_size,
637
+ width=vision_width,
638
+ layers=vision_layers,
639
+ heads=vision_heads,
640
+ output_dim=embed_dim,
641
+ use_flash_attention=use_flash_attention
642
+ )
643
+
644
+ self.bert_config = BertConfig(
645
+ vocab_size_or_config_json_file=vocab_size,
646
+ hidden_size=text_hidden_size,
647
+ num_hidden_layers=text_num_hidden_layers,
648
+ num_attention_heads=text_num_attention_heads,
649
+ intermediate_size=text_intermediate_size,
650
+ hidden_act=text_hidden_act,
651
+ hidden_dropout_prob=text_hidden_dropout_prob,
652
+ attention_probs_dropout_prob=text_attention_probs_dropout_prob,
653
+ max_position_embeddings=text_max_position_embeddings,
654
+ type_vocab_size=text_type_vocab_size,
655
+ initializer_range=text_initializer_range,
656
+ layer_norm_eps=1e-12,
657
+ use_flash_attention=use_flash_attention
658
+ )
659
+ self.bert = BertModel(self.bert_config)
660
+
661
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, embed_dim))
662
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
663
+
664
+ self.tokenizer = tokenizer
665
+ self.ln_final = LayerNorm(text_hidden_size)
666
+
667
+ self.initialize_parameters()
668
+
669
+ def initialize_parameters(self):
670
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
671
+
672
+ if isinstance(self.visual, ModifiedResNet):
673
+ if self.visual.attnpool is not None:
674
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
675
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
676
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
677
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
678
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
679
+
680
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
681
+ for name, param in resnet_block.named_parameters():
682
+ if name.endswith("bn3.weight"):
683
+ nn.init.zeros_(param)
684
+
685
+ if self.text_projection is not None:
686
+ nn.init.normal_(self.text_projection, std=self.bert_config.hidden_size ** -0.5)
687
+
688
+ @torch.jit.ignore
689
+ def set_grad_checkpointing(self, enable=True):
690
+ self.visual.set_grad_checkpointing(enable)
691
+ self.bert.set_grad_checkpointing(enable)
692
+
693
+ @property
694
+ def dtype(self):
695
+ return self.visual.conv1.weight.dtype
696
+
697
+ def encode_image(self, image, mask_ratio=0):
698
+ if isinstance(self.visual, ModifiedResNet):
699
+ # mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
700
+ return self.visual(image.type(self.dtype))
701
+ return self.visual(image.type(self.dtype), mask_ratio)
702
+
703
+ # def encode_text(self, text):
704
+ # pad_index = self.tokenizer.vocab['[PAD]']
705
+ # attn_mask = text.ne(pad_index).type(self.dtype)
706
+ # x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
707
+ # return x[:, 0, :] @ self.text_projection
708
+ def encode_text(self, text):
709
+ pad_index = self.tokenizer.vocab['[PAD]']
710
+ attn_mask = text.ne(pad_index).type(self.dtype)
711
+ x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
712
+ x = self.ln_final(x).type(self.dtype)
713
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
714
+ return x
715
+
716
+ def forward(self, image, text, mask_ratio=0):
717
+ assert image is not None or text is not None, "text and image cannot both be None!"
718
+
719
+ if image is None:
720
+ return self.encode_text(text)
721
+ elif text is None:
722
+ return self.encode_image(image, mask_ratio)
723
+ image_features = self.encode_image(image, mask_ratio)
724
+ text_features = self.encode_text(text)
725
+
726
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
727
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
728
+
729
+ return image_features, text_features, self.logit_scale.exp()
730
+
731
+ def get_similarity(self, image, text):
732
+ image_features = self.encode_image(image)
733
+ text_features = self.encode_text(text)
734
+
735
+ # normalized features
736
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
737
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
738
+
739
+ # cosine similarity as logits
740
+ logit_scale = self.logit_scale.exp()
741
+ logits_per_image = logit_scale * image_features @ text_features.t()
742
+ logits_per_text = logits_per_image.t()
743
+
744
+ # shape = [global_batch_size, global_batch_size]
745
+ return logits_per_image, logits_per_text
746
+
747
+ def convert_models_to_fp32(model):
748
+ for p in model.parameters():
749
+ p.data = p.data.float()
750
+ if p.grad is not None:
751
+ p.grad.data = p.grad.data.float()
752
+
753
+
754
+ def convert_weights(model: nn.Module):
755
+ """Convert applicable model parameters to fp16"""
756
+
757
+ def _convert_weights_to_fp16(l):
758
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
759
+ l.weight.data = l.weight.data.half()
760
+ if l.bias is not None:
761
+ l.bias.data = l.bias.data.half()
762
+
763
+ if isinstance(l, nn.MultiheadAttention):
764
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
765
+ tensor = getattr(l, attr)
766
+ if tensor is not None:
767
+ tensor.data = tensor.data.half()
768
+
769
+ if isinstance(l, BertModel):
770
+ l.to(torch.half)
771
+
772
+ for name in ["text_projection", "proj"]:
773
+ try:
774
+ if hasattr(l, name):
775
+ attr = getattr(l, name)
776
+ if attr is not None:
777
+ attr.data = attr.data.half()
778
+ except:
779
+ print('name', name)
780
+
781
+ model.apply(_convert_weights_to_fp16)
782
+
783
+
784
+ def restore_model(model, clip_state_dict: dict, bert_state_dict: dict, use_flash_attention: bool):
785
+ merged_state_dict = {}
786
+
787
+ # use clip_state_dict to initialize the image encoder & logit scale
788
+ if clip_state_dict is not None:
789
+ for k, v in clip_state_dict.items():
790
+ if k.startswith("visual") or k == "logit_scale":
791
+ merged_state_dict[k] = v
792
+
793
+ # use bert_state_dict to initialize the text encoder
794
+ if bert_state_dict is not None:
795
+ for k, v in bert_state_dict.items():
796
+ if k.startswith("bert") and "bert.pooler" not in k:
797
+ merged_state_dict[k] = v
798
+
799
+ # adapt flash attention
800
+ if use_flash_attention:
801
+ merged_state_dict = convert_state_dict(merged_state_dict)
802
+
803
+ convert_weights(model)
804
+ resize_pos_embed(merged_state_dict, model)
805
+ model.load_state_dict(merged_state_dict, strict=False)
806
+ return model.eval()
807
+
808
+
809
+ def convert_state_dict(state_dict):
810
+ """Adapt to Flash Attention"""
811
+ if not state_dict:
812
+ return state_dict
813
+
814
+ prefix = 'module.' if list(state_dict.keys())[0].startswith('module') else ''
815
+
816
+ if f'{prefix}visual.transformer.resblocks.0.attn.in_proj_weight' in state_dict:
817
+ for k in list(state_dict.keys()):
818
+ if 'attn.in_proj_weight' in k:
819
+ state_dict[k.replace('attn.in_proj_weight', 'attn.Wqkv.weight')] = state_dict.pop(k)
820
+ elif 'attn.in_proj_bias' in k:
821
+ state_dict[k.replace('attn.in_proj_bias', 'attn.Wqkv.bias')] = state_dict.pop(k)
822
+ elif f'{prefix}visual.transformer.resblocks.0.attn.Wqkv.weight' in state_dict:
823
+ for k in list(state_dict.keys()):
824
+ if 'attn.Wqkv.weight' in k:
825
+ state_dict[k.replace('attn.Wqkv.weight', 'attn.in_proj_weight')] = state_dict.pop(k)
826
+ elif 'attn.Wqkv.bias' in k:
827
+ state_dict[k.replace('attn.Wqkv.bias', 'attn.in_proj_bias')] = state_dict.pop(k)
828
+
829
+ if f'{prefix}bert.encoder.layer.0.attention.self.query.weight' in state_dict:
830
+ i = 0
831
+ while f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight' in state_dict:
832
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'] = torch.cat(
833
+ (state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'),
834
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'),
835
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'))
836
+ )
837
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'] = torch.cat(
838
+ (state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'),
839
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'),
840
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'))
841
+ )
842
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight'] = \
843
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight')
844
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias'] = \
845
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias')
846
+ i += 1
847
+ elif f'{prefix}bert.encoder.layer.0.attention.self.Wqkv.weight' in state_dict:
848
+ i = 0
849
+ while f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight' in state_dict:
850
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'], \
851
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'], \
852
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'] = \
853
+ torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'), chunks=3)
854
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'], \
855
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'], \
856
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'] = \
857
+ torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'), chunks=3)
858
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight'] = \
859
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight')
860
+ state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias'] = \
861
+ state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias')
862
+ i += 1
863
+
864
+ return state_dict
865
+
866
+
867
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1, prefix=""):
868
+ # Rescale the grid of position embeddings when loading from state_dict
869
+ old_pos_embed = state_dict.get(prefix + 'visual.positional_embedding', None)
870
+ model = model.module if hasattr(model, 'module') else model
871
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
872
+ return
873
+ grid_size = to_2tuple(model.visual.grid_size)
874
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
875
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
876
+ if new_seq_len == old_pos_embed.shape[0]:
877
+ return
878
+
879
+ if extra_tokens:
880
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
881
+ else:
882
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
883
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
884
+
885
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
886
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
887
+ pos_emb_img = F.interpolate(
888
+ pos_emb_img,
889
+ size=grid_size,
890
+ mode=interpolation,
891
+ align_corners=True,
892
+ )
893
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
894
+ if pos_emb_tok is not None:
895
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
896
+ else:
897
+ new_pos_embed = pos_emb_img
898
+ state_dict[prefix + 'visual.positional_embedding'] = new_pos_embed
899
+
900
+
901
+ # From PyTorch internals
902
+ def _ntuple(n):
903
+ def parse(x):
904
+ if isinstance(x, collections.abc.Iterable):
905
+ return x
906
+ return tuple(repeat(x, n))
907
+ return parse
908
+
909
+
910
+ to_1tuple = _ntuple(1)
911
+ to_2tuple = _ntuple(2)
912
+ to_3tuple = _ntuple(3)
913
+ to_4tuple = _ntuple(4)
914
+ to_ntuple = lambda n, x: _ntuple(n)(x)
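For orientation, here is a minimal inference sketch for the model classes defined above. It is illustrative and not part of the commit: it assumes the package is importable as `clip`, that `load_from_name` and `tokenize` from `clip/utils.py` (added later in this commit) are the entry points, and that a local image file `example.jpg` exists.

```python
# Minimal zero-shot sketch (illustrative; relies on the helpers from clip/utils.py below).
import torch
from PIL import Image
from clip.utils import load_from_name, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device)  # downloads QA-CLIP-base.pt on first use

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # example.jpg is hypothetical
text = tokenize(["一只猫", "一只狗"]).to(device)

with torch.no_grad():
    image_features, text_features, logit_scale = model(image, text)
    # CLIP.forward already L2-normalizes both feature sets, so a dot product is a cosine similarity.
    logits_per_image = logit_scale * image_features @ text_features.t()
    probs = logits_per_image.softmax(dim=-1)
print(probs)  # probability of each text matching the image
```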
clip/model_configs/RBT3-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 768,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 3072,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 12,
11
+ "text_num_hidden_layers": 3,
12
+ "text_type_vocab_size": 2
13
+ }
clip/model_configs/RN50.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "image_resolution": 224,
4
+ "vision_layers": "[3,4,6,3]",
5
+ "vision_width": 64,
6
+ "vision_patch_size": null
7
+ }
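For reference, a small sketch of how these vision/text config files are combined, mirroring the `create_model` helper in `clip/utils.py` further below (paths are written out explicitly here for illustration). Note that the ResNet `vision_layers` entry is stored as a string and parsed back into a list:

```python
# Sketch of config merging as done in clip/utils.py::create_model (illustrative paths).
import json

with open("clip/model_configs/RN50.json") as fv, open("clip/model_configs/RBT3-chinese.json") as ft:
    model_info = json.load(fv)
    model_info.update(json.load(ft))

# For ResNet variants, vision_layers is the string "[3,4,6,3]"; the loader eval()s it into a list.
if isinstance(model_info["vision_layers"], str):
    model_info["vision_layers"] = eval(model_info["vision_layers"])  # -> [3, 4, 6, 3]

# CLIP.__init__ then branches on the type: a tuple/list selects ModifiedResNet,
# while an int (as in the ViT configs) selects VisualTransformer.
```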
clip/model_configs/RoBERTa-wwm-ext-base-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 768,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 3072,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 12,
11
+ "text_num_hidden_layers": 12,
12
+ "text_type_vocab_size": 2
13
+ }
clip/model_configs/RoBERTa-wwm-ext-large-chinese.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "vocab_size": 21128,
3
+ "text_attention_probs_dropout_prob": 0.1,
4
+ "text_hidden_act": "gelu",
5
+ "text_hidden_dropout_prob": 0.1,
6
+ "text_hidden_size": 1024,
7
+ "text_initializer_range": 0.02,
8
+ "text_intermediate_size": 4096,
9
+ "text_max_position_embeddings": 512,
10
+ "text_num_attention_heads": 16,
11
+ "text_num_hidden_layers": 24,
12
+ "text_type_vocab_size": 2
13
+ }
clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_resolution": 224,
4
+ "vision_layers": 12,
5
+ "vision_width": 768,
6
+ "vision_patch_size": 16
7
+ }
clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "image_resolution": 224,
4
+ "vision_layers": 12,
5
+ "vision_width": 768,
6
+ "vision_patch_size": 32
7
+ }
clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "image_resolution": 224,
4
+ "vision_layers": 32,
5
+ "vision_width": 1280,
6
+ "vision_head_width": 80,
7
+ "vision_patch_size": 14
8
+ }
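The `vision_head_width` field appears only in this config because ViT-H/14 uses 80-dimensional attention heads instead of the default 64. A quick check of the head-count arithmetic from `CLIP.__init__` above:

```python
# Head counts implied by the configs (vision_heads = vision_width // vision_head_width).
assert 1280 // 80 == 16   # ViT-H/14 with vision_head_width = 80
assert 768 // 64 == 12    # ViT-B/16 and ViT-B/32 with the default head width of 64
assert 1024 // 64 == 16   # ViT-L/14 and ViT-L/14-336
```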
clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "image_resolution": 336,
4
+ "vision_layers": 24,
5
+ "vision_width": 1024,
6
+ "vision_patch_size": 14
7
+ }
clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "image_resolution": 224,
4
+ "vision_layers": 24,
5
+ "vision_width": 1024,
6
+ "vision_patch_size": 14
7
+ }
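These JSON files are selected through `struct` names of the form `<vision>@<text>` used in `clip/utils.py` below; a sketch of the lookup, with simplified path handling:

```python
# Sketch of how a struct name resolves to two config files (mirrors clip/utils.py::create_model).
from pathlib import Path

model_name = "ViT-L-14@RoBERTa-wwm-ext-base-chinese"
vision_model, text_model = model_name.split("@")

config_dir = Path("clip/model_configs")
vision_cfg = config_dir / f"{vision_model.replace('/', '-')}.json"  # ViT-L-14.json
text_cfg = config_dir / f"{text_model.replace('/', '-')}.json"      # RoBERTa-wwm-ext-base-chinese.json
```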
clip/modeling_bert.py ADDED
@@ -0,0 +1,484 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ from io import open
26
+
27
+ import torch
28
+ from torch import nn
29
+ from torch.utils.checkpoint import checkpoint
30
+
31
+ import importlib.util
32
+ if importlib.util.find_spec('flash_attn'):
33
+ FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
34
+
35
+ from .configuration_bert import BertConfig
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ def gelu(x):
40
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
41
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
42
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
43
+ Also see https://arxiv.org/abs/1606.08415
44
+ """
45
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
46
+
47
+ def gelu_new(x):
48
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
49
+ Also see https://arxiv.org/abs/1606.08415
50
+ """
51
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
52
+
53
+ def swish(x):
54
+ return x * torch.sigmoid(x)
55
+
56
+
57
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
58
+
59
+
60
+ BertLayerNorm = torch.nn.LayerNorm
61
+
62
+ class BertEmbeddings(nn.Module):
63
+ """Construct the embeddings from word, position and token_type embeddings.
64
+ """
65
+ def __init__(self, config):
66
+ super(BertEmbeddings, self).__init__()
67
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
68
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
69
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
70
+
71
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
72
+ # any TensorFlow checkpoint file
73
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
74
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
75
+
76
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
77
+ seq_length = input_ids.size(1)
78
+ if position_ids is None:
79
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
80
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
81
+ if token_type_ids is None:
82
+ token_type_ids = torch.zeros_like(input_ids)
83
+
84
+ words_embeddings = self.word_embeddings(input_ids)
85
+ position_embeddings = self.position_embeddings(position_ids)
86
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
87
+
88
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
89
+ embeddings = self.LayerNorm(embeddings)
90
+ embeddings = self.dropout(embeddings)
91
+ return embeddings
92
+
93
+
94
+ class BertSelfAttention(nn.Module):
95
+ def __init__(self, config):
96
+ super(BertSelfAttention, self).__init__()
97
+ if config.hidden_size % config.num_attention_heads != 0:
98
+ raise ValueError(
99
+ "The hidden size (%d) is not a multiple of the number of attention "
100
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
101
+ self.output_attentions = config.output_attentions
102
+
103
+ self.num_attention_heads = config.num_attention_heads
104
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
105
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
106
+
107
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
108
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
+
111
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
+
113
+ def transpose_for_scores(self, x):
114
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
115
+ x = x.view(*new_x_shape)
116
+ return x.permute(0, 2, 1, 3)
117
+
118
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
119
+ mixed_query_layer = self.query(hidden_states)
120
+ mixed_key_layer = self.key(hidden_states)
121
+ mixed_value_layer = self.value(hidden_states)
122
+
123
+ query_layer = self.transpose_for_scores(mixed_query_layer)
124
+ key_layer = self.transpose_for_scores(mixed_key_layer)
125
+ value_layer = self.transpose_for_scores(mixed_value_layer)
126
+
127
+ # Take the dot product between "query" and "key" to get the raw attention scores.
128
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
129
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
130
+ if attention_mask is not None:
131
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
132
+ attention_scores = attention_scores + attention_mask
133
+
134
+ # Normalize the attention scores to probabilities.
135
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
136
+
137
+ # This is actually dropping out entire tokens to attend to, which might
138
+ # seem a bit unusual, but is taken from the original Transformer paper.
139
+ attention_probs = self.dropout(attention_probs)
140
+
141
+ # Mask heads if we want to
142
+ if head_mask is not None:
143
+ attention_probs = attention_probs * head_mask
144
+
145
+ context_layer = torch.matmul(attention_probs, value_layer)
146
+
147
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
148
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
149
+ context_layer = context_layer.view(*new_context_layer_shape)
150
+
151
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
152
+ return outputs
153
+
154
+
155
+ class BertSelfOutput(nn.Module):
156
+ def __init__(self, config):
157
+ super(BertSelfOutput, self).__init__()
158
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
159
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
160
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
161
+
162
+ def forward(self, hidden_states, input_tensor):
163
+ hidden_states = self.dense(hidden_states)
164
+ hidden_states = self.dropout(hidden_states)
165
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
166
+ return hidden_states
167
+
168
+
169
+ class BertAttention(nn.Module):
170
+ def __init__(self, config):
171
+ super(BertAttention, self).__init__()
172
+ self.self = BertSelfAttention(config) if not config.use_flash_attention else FlashMHA(config.hidden_size, config.num_attention_heads)
173
+ self.output = BertSelfOutput(config) if not config.use_flash_attention else BertSelfOutputForFlashAttention(config)
174
+ self.pruned_heads = set()
175
+ self.config = config
176
+
177
+ def forward(self, input_tensor, attention_mask=None, head_mask=None):
178
+ if not self.config.use_flash_attention:
179
+ self_outputs = self.self(input_tensor, attention_mask, head_mask)
180
+ else:
181
+ key_padding_mask = self.get_key_padding_mask(attention_mask)
182
+ self_outputs = self.self(input_tensor, key_padding_mask=key_padding_mask)
183
+ attention_output = self.output(self_outputs[0], input_tensor)
184
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
185
+ return outputs
186
+
187
+ def get_key_padding_mask(self, attention_mask):
188
+ # key_padding_mask: bool tensor of shape (batch, seqlen)
189
+ return attention_mask.squeeze(1).squeeze(1) == 0
190
+
191
+
192
+ class BertIntermediate(nn.Module):
193
+ def __init__(self, config):
194
+ super(BertIntermediate, self).__init__()
195
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
196
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
197
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
198
+ else:
199
+ self.intermediate_act_fn = config.hidden_act
200
+
201
+ def forward(self, hidden_states):
202
+ hidden_states = self.dense(hidden_states)
203
+ hidden_states = self.intermediate_act_fn(hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class BertOutput(nn.Module):
208
+ def __init__(self, config):
209
+ super(BertOutput, self).__init__()
210
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
211
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
212
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
213
+
214
+ def forward(self, hidden_states, input_tensor):
215
+ hidden_states = self.dense(hidden_states)
216
+ hidden_states = self.dropout(hidden_states)
217
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
218
+ return hidden_states
219
+
220
+
221
+ class BertSelfOutputForFlashAttention(nn.Module): # remove linear layer
222
+ def __init__(self, config):
223
+ super(BertSelfOutputForFlashAttention, self).__init__()
224
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
+
227
+ def forward(self, hidden_states, input_tensor):
228
+ hidden_states = self.dropout(hidden_states)
229
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
230
+ return hidden_states
231
+
232
+
233
+ class BertLayer(nn.Module):
234
+ def __init__(self, config):
235
+ super(BertLayer, self).__init__()
236
+ self.attention = BertAttention(config)
237
+ self.intermediate = BertIntermediate(config)
238
+ self.output = BertOutput(config)
239
+
240
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
241
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
242
+ attention_output = attention_outputs[0]
243
+ intermediate_output = self.intermediate(attention_output)
244
+ layer_output = self.output(intermediate_output, attention_output)
245
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
246
+ if len(outputs) == 1:
247
+ return outputs[0]
248
+ return outputs
249
+
250
+
251
+ class BertEncoder(nn.Module):
252
+ def __init__(self, config):
253
+ super(BertEncoder, self).__init__()
254
+ self.output_attentions = config.output_attentions
255
+ self.output_hidden_states = config.output_hidden_states
256
+ self.grad_checkpointing = False
257
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
258
+
259
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
260
+ all_hidden_states = ()
261
+ all_attentions = ()
262
+ for i, layer_module in enumerate(self.layer):
263
+ if self.output_hidden_states:
264
+ all_hidden_states = all_hidden_states + (hidden_states,)
265
+
266
+ if self.grad_checkpointing and not torch.jit.is_scripting():
267
+ layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i])
268
+ else:
269
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
270
+ if not isinstance(layer_outputs, tuple):
271
+ layer_outputs = (layer_outputs, )
272
+ hidden_states = layer_outputs[0]
273
+
274
+ if self.output_attentions:
275
+ all_attentions = all_attentions + (layer_outputs[1],)
276
+
277
+ # Add last layer
278
+ if self.output_hidden_states:
279
+ all_hidden_states = all_hidden_states + (hidden_states,)
280
+
281
+ outputs = (hidden_states,)
282
+ if self.output_hidden_states:
283
+ outputs = outputs + (all_hidden_states,)
284
+ if self.output_attentions:
285
+ outputs = outputs + (all_attentions,)
286
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
287
+
288
+
289
+ class BertPooler(nn.Module):
290
+ def __init__(self, config):
291
+ super(BertPooler, self).__init__()
292
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
293
+ self.activation = nn.Tanh()
294
+
295
+ def forward(self, hidden_states):
296
+ # We "pool" the model by simply taking the hidden state corresponding
297
+ # to the first token.
298
+ first_token_tensor = hidden_states[:, 0]
299
+ pooled_output = self.dense(first_token_tensor)
300
+ pooled_output = self.activation(pooled_output)
301
+ return pooled_output
302
+
303
+
304
+ class BertPredictionHeadTransform(nn.Module):
305
+ def __init__(self, config):
306
+ super(BertPredictionHeadTransform, self).__init__()
307
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
308
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
309
+ self.transform_act_fn = ACT2FN[config.hidden_act]
310
+ else:
311
+ self.transform_act_fn = config.hidden_act
312
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
313
+
314
+ def forward(self, hidden_states):
315
+ hidden_states = self.dense(hidden_states)
316
+ hidden_states = self.transform_act_fn(hidden_states)
317
+ hidden_states = self.LayerNorm(hidden_states)
318
+ return hidden_states
319
+
320
+
321
+ class BertLMPredictionHead(nn.Module):
322
+ def __init__(self, config):
323
+ super(BertLMPredictionHead, self).__init__()
324
+ self.transform = BertPredictionHeadTransform(config)
325
+
326
+ # The output weights are the same as the input embeddings, but there is
327
+ # an output-only bias for each token.
328
+ self.decoder = nn.Linear(config.hidden_size,
329
+ config.vocab_size,
330
+ bias=False)
331
+
332
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
333
+
334
+ def forward(self, hidden_states):
335
+ hidden_states = self.transform(hidden_states)
336
+ hidden_states = self.decoder(hidden_states) + self.bias
337
+ return hidden_states
338
+
339
+
340
+ class BertOnlyMLMHead(nn.Module):
341
+ def __init__(self, config):
342
+ super(BertOnlyMLMHead, self).__init__()
343
+ self.predictions = BertLMPredictionHead(config)
344
+
345
+ def forward(self, sequence_output):
346
+ prediction_scores = self.predictions(sequence_output)
347
+ return prediction_scores
348
+
349
+
350
+ class BertOnlyNSPHead(nn.Module):
351
+ def __init__(self, config):
352
+ super(BertOnlyNSPHead, self).__init__()
353
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
354
+
355
+ def forward(self, pooled_output):
356
+ seq_relationship_score = self.seq_relationship(pooled_output)
357
+ return seq_relationship_score
358
+
359
+
360
+ class BertPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super(BertPreTrainingHeads, self).__init__()
363
+ self.predictions = BertLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class BertPreTrainedModel(nn.Module):
373
+ config_class = BertConfig
374
+ base_model_prefix = "bert"
375
+
376
+ def __init__(self, config):
377
+ super(BertPreTrainedModel, self).__init__()
378
+ self.config = config
379
+
380
+ def _init_weights(self, module):
381
+ """ Initialize the weights """
382
+ if isinstance(module, (nn.Linear, nn.Embedding)):
383
+ # Slightly different from the TF version which uses truncated_normal for initialization
384
+ # cf https://github.com/pytorch/pytorch/pull/5617
385
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
386
+ elif isinstance(module, BertLayerNorm):
387
+ module.bias.data.zero_()
388
+ module.weight.data.fill_(1.0)
389
+ if isinstance(module, nn.Linear) and module.bias is not None:
390
+ module.bias.data.zero_()
391
+
392
+
393
+ class BertModel(BertPreTrainedModel):
394
+ r"""
395
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
396
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
397
+ Sequence of hidden-states at the output of the last layer of the model.
398
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
399
+ Last layer hidden-state of the first token of the sequence (classification token)
400
+ further processed by a Linear layer and a Tanh activation function. The Linear
401
+ layer weights are trained from the next sentence prediction (classification)
402
+ objective during Bert pretraining. This output is usually *not* a good summary
403
+ of the semantic content of the input, you're often better with averaging or pooling
404
+ the sequence of hidden-states for the whole input sequence.
405
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
406
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
407
+ of shape ``(batch_size, sequence_length, hidden_size)``:
408
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
409
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
410
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
411
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
412
+
413
+ Examples::
414
+
415
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
416
+ model = BertModel.from_pretrained('bert-base-uncased')
417
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
418
+ outputs = model(input_ids)
419
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
420
+
421
+ """
422
+ def __init__(self, config):
423
+ super(BertModel, self).__init__(config)
424
+
425
+ self.embeddings = BertEmbeddings(config)
426
+ self.encoder = BertEncoder(config)
427
+ # self.pooler = BertPooler(config)
428
+
429
+ self.apply(self._init_weights)
430
+
431
+ @torch.jit.ignore
432
+ def set_grad_checkpointing(self, enable=True):
433
+ if enable:
434
+ assert not self.config.output_attentions, \
435
+ "Grad checkpointing is currently conflict with output_attentions for BertEncoder, \
436
+ please set it to False in BertConfig"
437
+ self.encoder.grad_checkpointing = enable
438
+
439
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
440
+ if attention_mask is None:
441
+ attention_mask = torch.ones_like(input_ids)
442
+ if token_type_ids is None:
443
+ token_type_ids = torch.zeros_like(input_ids)
444
+
445
+ # We create a 3D attention mask from a 2D tensor mask.
446
+ # Sizes are [batch_size, 1, 1, to_seq_length]
447
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
448
+ # this attention mask is simpler than the triangular masking of causal attention
449
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
450
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
451
+
452
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
453
+ # masked positions, this operation will create a tensor which is 0.0 for
454
+ # positions we want to attend and -10000.0 for masked positions.
455
+ # Since we are adding it to the raw scores before the softmax, this is
456
+ # effectively the same as removing these entirely.
457
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
458
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
459
+
460
+ # Prepare head mask if needed
461
+ # 1.0 in head_mask indicate we keep the head
462
+ # attention_probs has shape bsz x n_heads x N x N
463
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
464
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
465
+ if head_mask is not None:
466
+ if head_mask.dim() == 1:
467
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
468
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
469
+ elif head_mask.dim() == 2:
470
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
471
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
472
+ else:
473
+ head_mask = [None] * self.config.num_hidden_layers
474
+
475
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
476
+ encoder_outputs = self.encoder(embedding_output,
477
+ extended_attention_mask,
478
+ head_mask=head_mask)
479
+ sequence_output = encoder_outputs[0]
480
+ # pooled_output = self.pooler(sequence_output)
481
+ pooled_output = None
482
+
483
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
484
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
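A small numeric sketch of the additive attention mask built in `BertModel.forward` above: the 2D padding mask is broadcast to 4D and padded positions are pushed to -10000.0 so they effectively vanish after the softmax (the cast to the model's parameter dtype is omitted here).

```python
# Illustrative example of the extended attention mask from BertModel.forward.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # (batch, seq_len); 0 marks [PAD] positions
extended = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len), broadcastable over heads
extended = (1.0 - extended.float()) * -10000.0       # 0.0 where attended, -10000.0 where padded

print(extended)  # 0.0 (printed as -0.) for real tokens, -10000.0 for padding
```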
clip/utils.py ADDED
@@ -0,0 +1,184 @@
1
+ # Code modified from https://github.com/openai/CLIP
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Union, List
7
+ import urllib
8
+
9
+ import torch
10
+ from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
11
+ from tqdm import tqdm
12
+
13
+ from clip import _tokenizer
14
+ from clip.model import convert_weights, CLIP, restore_model
15
+
16
+ __all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]
17
+
18
+ _MODELS = {
19
+ "ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt",
20
+ "ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt",
21
+ "RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt",
22
+ }
23
+ _MODEL_INFO = {
24
+ "ViT-B-16": {
25
+ "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
26
+ "input_resolution": 224
27
+ },
28
+ "ViT-L-14": {
29
+ "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
30
+ "input_resolution": 224
31
+ },
32
+ "RN50": {
33
+ "struct": "RN50@RBT3-chinese",
34
+ "input_resolution": 224
35
+ },
36
+ }
37
+
38
+
39
+ def _download(url: str, root: str):
40
+ os.makedirs(root, exist_ok=True)
41
+ filename = os.path.basename(url)
42
+
43
+ download_target = os.path.join(root, filename)
44
+
45
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
46
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
47
+
48
+ if os.path.isfile(download_target):
49
+ return download_target
50
+
51
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
52
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
53
+ unit_divisor=1024) as loop:
54
+ while True:
55
+ buffer = source.read(8192)
56
+ if not buffer:
57
+ break
58
+
59
+ output.write(buffer)
60
+ loop.update(len(buffer))
61
+
62
+ return download_target
63
+
64
+
65
+ def _convert_image_to_rgb(image):
66
+ return image.convert("RGB")
67
+
68
+
69
+ def available_models() -> List[str]:
70
+ """Returns the names of available CLIP models"""
71
+ return list(_MODELS.keys())
72
+
73
+
74
+ def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
75
+ download_root: str = None, vision_model_name: str = None, text_model_name: str = None, input_resolution: int = None):
76
+ if name in _MODELS:
77
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
78
+ model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
79
+ elif os.path.isfile(name):
80
+ assert vision_model_name and text_model_name and input_resolution, "Please specify 'vision_model_name', 'text_model_name', and 'input_resolution'"
81
+ model_path = name
82
+ model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
83
+ else:
84
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
85
+
86
+ with open(model_path, 'rb') as opened_file:
87
+ # loading saved checkpoint
88
+ checkpoint = torch.load(opened_file, map_location="cpu")
89
+
90
+ model = create_model(model_name, checkpoint)
91
+ if str(device) == "cpu":
92
+ model.float()
93
+ else:
94
+ model.to(device)
95
+ return model, image_transform(model_input_resolution)
96
+
97
+
98
+ def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
99
+ bert_path=None, use_flash_attention=False):
100
+ """Load CLIP and BERT model weights
101
+ """
102
+
103
+ bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
104
+ clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None
105
+
106
+ restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)
107
+
108
+ if str(device) == "cpu":
109
+ model.float()
110
+ return model
111
+
112
+
113
+ def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
114
+ """
115
+ Returns the tokenized representation of given input string(s)
116
+ Parameters
117
+ ----------
118
+ texts : Union[str, List[str]]
119
+ An input string or a list of input strings to tokenize
120
+ context_length : int
121
+ The context length to use; all baseline models use 52 as the context length
122
+ Returns
123
+ -------
124
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
125
+ """
126
+ if isinstance(texts, str):
127
+ texts = [texts]
128
+
129
+ all_tokens = []
130
+ for text in texts:
131
+ all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
132
+ :context_length - 2] + [_tokenizer.vocab['[SEP]']])
133
+
134
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
135
+
136
+ for i, tokens in enumerate(all_tokens):
137
+ assert len(tokens) <= context_length
138
+ result[i, :len(tokens)] = torch.tensor(tokens)
139
+
140
+ return result
141
+
142
+
143
+ def _convert_to_rgb(image):
144
+ return image.convert('RGB')
145
+
146
+
147
+ def image_transform(image_size=224):
148
+ transform = Compose([
149
+ Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
150
+ _convert_to_rgb,
151
+ ToTensor(),
152
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
153
+ ])
154
+ return transform
155
+
156
+
157
+ def create_model(model_name, checkpoint=None):
158
+ vision_model, text_model = model_name.split('@')
159
+ # Initialize the model.
160
+ vision_model_config_file = Path(
161
+ __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
162
+ print('Loading vision model config from', vision_model_config_file)
163
+ assert os.path.exists(vision_model_config_file)
164
+
165
+ text_model_config_file = Path(
166
+ __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
167
+ print('Loading text model config from', text_model_config_file)
168
+ assert os.path.exists(text_model_config_file)
169
+
170
+ with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
171
+ model_info = json.load(fv)
172
+ for k, v in json.load(ft).items():
173
+ model_info[k] = v
174
+ if isinstance(model_info['vision_layers'], str):
175
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
176
+ print('Model info', model_info)
177
+ model = CLIP(**model_info)
178
+ convert_weights(model)
179
+ if checkpoint:
180
+ sd = checkpoint["state_dict"]
181
+ if next(iter(sd.items()))[0].startswith('module'):
182
+ sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
183
+ model.load_state_dict(sd)
184
+ return model
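
For reference, a minimal usage sketch of the helpers defined in `clip/utils.py` above. The model name, download root, and image path are illustrative placeholders, and the `model(image, None)` / `model(None, texts)` calling convention mirrors the evaluation scripts later in this upload:

```python
# Minimal sketch: encode one image and two candidate captions with QA-CLIP.
# Assumes the `clip` package from this repo is on PYTHONPATH; "example.jpg" is a placeholder.
import torch
from PIL import Image
from clip.utils import load_from_name, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root="./")
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
texts = tokenize(["一只猫的照片", "一只狗的照片"]).to(device)

with torch.no_grad():
    image_features = model(image, None)   # passing text=None returns image features
    text_features = model(None, texts)    # passing image=None returns text features
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.t()).softmax(dim=-1)

print(probs)  # similarity of the image to each caption
```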
clip/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
eval/cvinw_zeroshot_templates.py ADDED
@@ -0,0 +1,474 @@
1
+ """
2
+ This script provides templates for manual prompting for zero-shot image classification.
3
+ """
4
+
5
+
6
+ openai_templates = [
7
+ lambda c: f"{c}的照片",
8
+ lambda c: f"质量差的{c}的照片",
9
+ lambda c: f"许多{c}的照片",
10
+ lambda c: f"{c}的雕塑",
11
+ lambda c: f"难以看到{c}的照片",
12
+ lambda c: f"{c}的低分辨率照片",
13
+ lambda c: f"{c}的渲染",
14
+ lambda c: f"涂鸦{c}",
15
+ lambda c: f"{c}的糟糕照片",
16
+ lambda c: f"{c}的裁剪照片",
17
+ lambda c: f"{c}的纹身",
18
+ lambda c: f"{c}的刺绣照片",
19
+ lambda c: f"很难看到{c}的照片",
20
+ lambda c: f"{c}的明亮照片",
21
+ lambda c: f"一张干净的{c}的照片",
22
+ lambda c: f"一张包含{c}的照片",
23
+ lambda c: f"{c}的深色照片",
24
+ lambda c: f"{c}的手绘画",
25
+ lambda c: f"我的{c}的照片",
26
+ lambda c: f"不自然的{c}的照片",
27
+ lambda c: f"一张酷的{c}的照片",
28
+ lambda c: f"{c}的特写照片",
29
+ lambda c: f"{c}的黑白照片",
30
+ lambda c: f"一幅{c}的画",
31
+ lambda c: f"一幅{c}的绘画",
32
+ lambda c: f"一张{c}的像素照片",
33
+ lambda c: f"{c}的雕像",
34
+ lambda c: f"一张{c}的明亮照片",
35
+ lambda c: f"{c}的裁剪照片",
36
+ lambda c: f"人造的{c}的照片",
37
+ lambda c: f"一张关于{c}的照片",
38
+ lambda c: f"损坏的{c}的jpeg照片",
39
+ lambda c: f"{c}的模糊照片",
40
+ lambda c: f"{c}的相片",
41
+ lambda c: f"一张{c}的好照片",
42
+ lambda c: f"{c}的渲染照",
43
+ lambda c: f"视频游戏中的{c}",
44
+ lambda c: f"一张{c}的照片",
45
+ lambda c: f"{c}的涂鸦",
46
+ lambda c: f"{c}的近距离照片",
47
+ lambda c: f"{c}的折纸",
48
+ lambda c: f"{c}在视频游戏中",
49
+ lambda c: f"{c}的草图",
50
+ lambda c: f"{c}的涂鸦照",
51
+ lambda c: f"{c}的折纸形状",
52
+ lambda c: f"低分辨率的{c}的照片",
53
+ lambda c: f"玩具{c}",
54
+ lambda c: f"{c}的副本",
55
+ lambda c: f"{c}的干净的照片",
56
+ lambda c: f"一张大{c}的照片",
57
+ lambda c: f"{c}的重现",
58
+ lambda c: f"一张漂亮的{c}的照片",
59
+ lambda c: f"一张奇怪的{c}的照片",
60
+ lambda c: f"模糊的{c}的照片",
61
+ lambda c: f"卡通{c}",
62
+ lambda c: f"{c}的艺术作品",
63
+ lambda c: f"{c}的素描",
64
+ lambda c: f"刺绣{c}",
65
+ lambda c: f"{c}的像素照",
66
+ lambda c: f"{c}的拍照",
67
+ lambda c: f"{c}的损坏的照片",
68
+ lambda c: f"高质量的{c}的照片",
69
+ lambda c: f"毛绒玩具{c}",
70
+ lambda c: f"漂亮的{c}的照片",
71
+ lambda c: f"小{c}的照片",
72
+ lambda c: f"照片是奇怪的{c}",
73
+ lambda c: f"漫画{c}",
74
+ lambda c: f"{c}的艺术照",
75
+ lambda c: f"{c}的图形",
76
+ lambda c: f"大{c}的照片",
77
+ lambda c: f"黑白的{c}的照片",
78
+ lambda c: f"{c}毛绒玩具",
79
+ lambda c: f"一张{c}的深色照片",
80
+ lambda c: f"{c}的摄影图",
81
+ lambda c: f"{c}的涂鸦照",
82
+ lambda c: f"玩具形状的{c}",
83
+ lambda c: f"拍了{c}的照片",
84
+ lambda c: f"酷酷的{c}的照片",
85
+ lambda c: f"照片里的小{c}",
86
+ lambda c: f"{c}的刺青",
87
+ lambda c: f"{c}的可爱的照片",
88
+ lambda c: f"一张{c}可爱的照片",
89
+ lambda c: f"{c}可爱图片",
90
+ lambda c: f"{c}酷炫图片",
91
+ lambda c: f"一张{c}的酷炫的照片",
92
+ lambda c: f"一张{c}的酷炫图片",
93
+ lambda c: f"这是{c}",
94
+ lambda c: f"{c}的好看照片",
95
+ lambda c: f"一张{c}的好看的图片",
96
+ lambda c: f"{c}的好看图片",
97
+ lambda c: f"{c}的照片。",
98
+ lambda c: f"质量差的{c}的照片。",
99
+ lambda c: f"许多{c}的照片。",
100
+ lambda c: f"{c}的雕塑。",
101
+ lambda c: f"难以看到{c}的照片。",
102
+ lambda c: f"{c}的低分辨率照片。",
103
+ lambda c: f"{c}的渲染。",
104
+ lambda c: f"涂鸦{c}。",
105
+ lambda c: f"{c}的糟糕照片。",
106
+ lambda c: f"{c}的裁剪照片。",
107
+ lambda c: f"{c}的纹身。",
108
+ lambda c: f"{c}的刺绣照片。",
109
+ lambda c: f"很难看到{c}的照片。",
110
+ lambda c: f"{c}的明亮照片。",
111
+ lambda c: f"一张干净的{c}的照片。",
112
+ lambda c: f"一张包含{c}的照片。",
113
+ lambda c: f"{c}的深色照片。",
114
+ lambda c: f"{c}的手绘画。",
115
+ lambda c: f"我的{c}的照片。",
116
+ lambda c: f"不自然的{c}的照片。",
117
+ lambda c: f"一张酷的{c}的照片。",
118
+ lambda c: f"{c}的特写照片。",
119
+ lambda c: f"{c}的黑白照片。",
120
+ lambda c: f"一幅{c}的画。",
121
+ lambda c: f"一幅{c}的绘画。",
122
+ lambda c: f"一张{c}的像素照片。",
123
+ lambda c: f"{c}的雕像。",
124
+ lambda c: f"一张{c}的明亮照片。",
125
+ lambda c: f"{c}的裁剪照片。",
126
+ lambda c: f"人造的{c}的照片。",
127
+ lambda c: f"一张关于{c}的照片。",
128
+ lambda c: f"损坏的{c}的jpeg照片。",
129
+ lambda c: f"{c}的模糊照片。",
130
+ lambda c: f"{c}的相片。",
131
+ lambda c: f"一张{c}的好照片。",
132
+ lambda c: f"{c}的渲染照。",
133
+ lambda c: f"视频游戏中的{c}。",
134
+ lambda c: f"一张{c}的照片。",
135
+ lambda c: f"{c}的涂鸦。",
136
+ lambda c: f"{c}的近距离照片。",
137
+ lambda c: f"{c}的折纸。",
138
+ lambda c: f"{c}在视频游戏中。",
139
+ lambda c: f"{c}的草图。",
140
+ lambda c: f"{c}的涂鸦照。",
141
+ lambda c: f"{c}的折纸形状。",
142
+ lambda c: f"低分辨率的{c}的照片。",
143
+ lambda c: f"玩具{c}。",
144
+ lambda c: f"{c}的副本。",
145
+ lambda c: f"{c}的干净的照片。",
146
+ lambda c: f"一张大{c}的照片。",
147
+ lambda c: f"{c}的重现。",
148
+ lambda c: f"一张漂亮的{c}的照片。",
149
+ lambda c: f"一张奇怪的{c}的照片。",
150
+ lambda c: f"模糊的{c}的照片。",
151
+ lambda c: f"卡通{c}。",
152
+ lambda c: f"{c}的艺术作品。",
153
+ lambda c: f"{c}的素描。",
154
+ lambda c: f"刺绣{c}。",
155
+ lambda c: f"{c}的像素照。",
156
+ lambda c: f"{c}的拍照。",
157
+ lambda c: f"{c}的损坏的照片。",
158
+ lambda c: f"高质量的{c}的照片。",
159
+ lambda c: f"毛绒玩具{c}。",
160
+ lambda c: f"漂亮的{c}的照片。",
161
+ lambda c: f"小{c}的照片。",
162
+ lambda c: f"照片是奇怪的{c}。",
163
+ lambda c: f"漫画{c}。",
164
+ lambda c: f"{c}的艺术照。",
165
+ lambda c: f"{c}的图形。",
166
+ lambda c: f"大{c}的照片。",
167
+ lambda c: f"黑白的{c}的照片。",
168
+ lambda c: f"{c}毛绒玩具。",
169
+ lambda c: f"一张{c}的深色照片。",
170
+ lambda c: f"{c}的摄影图。",
171
+ lambda c: f"{c}的涂鸦照。",
172
+ lambda c: f"玩具形状的{c}。",
173
+ lambda c: f"拍了{c}的照片。",
174
+ lambda c: f"酷酷的{c}的照片。",
175
+ lambda c: f"照片里的小{c}。",
176
+ lambda c: f"{c}的刺青。",
177
+ lambda c: f"{c}的可爱的照片。",
178
+ lambda c: f"一张{c}可爱的照片。",
179
+ lambda c: f"{c}可爱图片。",
180
+ lambda c: f"{c}酷炫图片。",
181
+ lambda c: f"一张{c}的酷炫的照片。",
182
+ lambda c: f"一张{c}的酷炫图片。",
183
+ lambda c: f"这是{c}。",
184
+ lambda c: f"{c}的好看照片。",
185
+ lambda c: f"一张{c}的好看的图片。",
186
+ lambda c: f"{c}的好看图片。",
187
+ lambda c: f"一种叫{c}的花的照片",
188
+ lambda c: f"一种叫{c}的食物的照片",
189
+ lambda c: f"{c}的卫星照片"
190
+ ]
191
+
192
+ normal_templates = [lambda c: f"{c}的图片"]
193
+
194
+ flower_templates = [
195
+ lambda c: f"一种叫{c}的花的照片",
196
+ lambda c: f"一种叫{c}的花卉的照片",
197
+ lambda c: f"一种叫{c}的花朵的照片",
198
+ lambda c: f"一种叫{c}的鲜花的照片",
199
+ lambda c: f"一种叫{c}的花的高清图",
200
+ lambda c: f"一种叫{c}的花卉的高清图",
201
+ lambda c: f"一种叫{c}的花朵的高清图",
202
+ lambda c: f"一种叫{c}的鲜花的高清图",
203
+ lambda c: f"一种叫{c}的花的模糊图片",
204
+ lambda c: f"一种叫{c}的花朵的模糊图片",
205
+ lambda c: f"一种叫{c}的花卉的模糊图片",
206
+ lambda c: f"一种叫{c}的鲜花的模糊图片",
207
+ lambda c: f"一种叫{c}的花的缩放图片",
208
+ lambda c: f"一种叫{c}的花朵的缩放图片",
209
+ lambda c: f"一种叫{c}的花卉的缩放图片",
210
+ lambda c: f"一种叫{c}的鲜花的缩放图片",
211
+ lambda c: f"一种叫{c}的花的摄影图",
212
+ lambda c: f"一种叫{c}的花卉的摄影图",
213
+ lambda c: f"一种叫{c}的花朵的摄影图",
214
+ lambda c: f"一种叫{c}的鲜花的摄影图",
215
+ lambda c: f"一种叫{c}的花的近距离照片",
216
+ lambda c: f"一种叫{c}的花朵的近距离照片",
217
+ lambda c: f"一种叫{c}的花卉的近距离照片",
218
+ lambda c: f"一种叫{c}的鲜花的近距离照片",
219
+ lambda c: f"一种叫{c}的花的裁剪照片",
220
+ lambda c: f"一种叫{c}的花朵的裁剪照片",
221
+ lambda c: f"一种叫{c}的花卉的裁剪照片",
222
+ lambda c: f"一种叫{c}的鲜花的裁剪照片",
223
+ lambda c: f"一种叫{c}的花的好看的图片",
224
+ lambda c: f"一种叫{c}的花朵的好看的图片",
225
+ lambda c: f"一种叫{c}的花卉的好看的图片",
226
+ lambda c: f"一种叫{c}的鲜花的好看的图片",
227
+ ]
228
+
229
+ food_templates = [
230
+ lambda c: f"一种叫{c}的食物的照片",
231
+ lambda c: f"一种叫{c}的美食的照片",
232
+ lambda c: f"一种叫{c}的菜的照片",
233
+ lambda c: f"一种叫{c}的食物的高清图",
234
+ lambda c: f"一种叫{c}的美食的高清图",
235
+ lambda c: f"一种叫{c}的菜的高清图",
236
+ lambda c: f"一种叫{c}的食物的模糊图片",
237
+ lambda c: f"一种叫{c}的美食的模糊图片",
238
+ lambda c: f"一种叫{c}的菜的模糊图片",
239
+ lambda c: f"一种叫{c}的食物的缩放图片",
240
+ lambda c: f"一种叫{c}的美食的缩放图片",
241
+ lambda c: f"一种叫{c}的菜的缩放图片",
242
+ lambda c: f"一种叫{c}的食物的摄影图",
243
+ lambda c: f"一种叫{c}的美食的摄影图",
244
+ lambda c: f"一种叫{c}的菜的摄影图",
245
+ lambda c: f"一种叫{c}的食物的近距离照片",
246
+ lambda c: f"一种叫{c}的美食的近距离照片",
247
+ lambda c: f"一种叫{c}的菜的近距离照片",
248
+ lambda c: f"一种叫{c}的食物的裁剪照片",
249
+ lambda c: f"一种叫{c}的美食的裁剪照片",
250
+ lambda c: f"一种叫{c}的菜的裁剪照片",
251
+ ]
252
+
253
+ aircraft_templates = [
254
+ lambda c: f"{c},飞机的照片",
255
+ lambda c: f"{c},飞机的高清图",
256
+ lambda c: f"{c},飞机的模糊图片",
257
+ lambda c: f"{c},飞机的缩放图片",
258
+ lambda c: f"{c},飞机的摄影图",
259
+ lambda c: f"{c},战斗机的照片",
260
+ lambda c: f"{c},战斗机的高清图",
261
+ lambda c: f"{c},战斗机的模糊图片",
262
+ lambda c: f"{c},战斗机的缩放图片",
263
+ lambda c: f"{c},战斗机的摄影图",
264
+ lambda c: f"{c},老飞机的照片",
265
+ lambda c: f"{c},老飞机的高清图",
266
+ lambda c: f"{c},老飞机的模糊图片",
267
+ lambda c: f"{c},老飞机的缩放图片",
268
+ lambda c: f"{c},老飞机的摄影图",
269
+ lambda c: f"{c},大飞机的照片",
270
+ lambda c: f"{c},大飞机的高清图",
271
+ lambda c: f"{c},大飞机的模糊图片",
272
+ lambda c: f"{c},大飞机的缩放图片",
273
+ lambda c: f"{c},大飞机的摄影图",
274
+ lambda c: f"{c},小飞机的照片",
275
+ lambda c: f"{c},小飞机的高清图",
276
+ lambda c: f"{c},小飞机的模糊图片",
277
+ lambda c: f"{c},小飞机的缩放图片",
278
+ lambda c: f"{c},小飞机的摄影图",
279
+ lambda c: f"{c},军用飞机的照片",
280
+ lambda c: f"{c},军用飞机的高清图",
281
+ lambda c: f"{c},军用飞机的模糊图片",
282
+ lambda c: f"{c},军用飞机的缩放图片",
283
+ lambda c: f"{c},军用飞机的摄影图",
284
+ lambda c: f"{c},运输机的照片",
285
+ lambda c: f"{c},运输机的高清图",
286
+ lambda c: f"{c},运输机的模糊图片",
287
+ lambda c: f"{c},运输机的缩放图片",
288
+ lambda c: f"{c},运输机的摄影图",
289
+ lambda c: f"{c},公务机的照片",
290
+ lambda c: f"{c},公务机的高清图",
291
+ lambda c: f"{c},公务机的模糊图片",
292
+ lambda c: f"{c},公务机的缩放图片",
293
+ lambda c: f"{c},公务机的摄影图",
294
+ lambda c: f"{c},客机的照片",
295
+ lambda c: f"{c},客机的高清图",
296
+ lambda c: f"{c},客机的模糊图片",
297
+ lambda c: f"{c},客机的缩放图片",
298
+ lambda c: f"{c},客机的摄影图",
299
+ lambda c: f"{c},喷气机的照片",
300
+ lambda c: f"{c},喷气机的高清图",
301
+ lambda c: f"{c},喷气机的模糊图片",
302
+ lambda c: f"{c},喷气机的缩放图片",
303
+ lambda c: f"{c},喷气机的摄影图",
304
+ lambda c: f"一种叫{c}的飞机的照片",
305
+ lambda c: f"一种叫{c}的飞机的高清图",
306
+ lambda c: f"一种叫{c}的飞机的模糊图片",
307
+ lambda c: f"一种叫{c}的飞机的缩放图片",
308
+ lambda c: f"一种叫{c}的飞机的摄影图",
309
+ lambda c: f"一种叫{c}的战斗机的照片",
310
+ lambda c: f"一种叫{c}的战斗机的高清图",
311
+ lambda c: f"一种叫{c}的战斗机的模糊图片",
312
+ lambda c: f"一种叫{c}的战斗机的缩放图片",
313
+ lambda c: f"一种叫{c}的战斗机的摄影图",
314
+ lambda c: f"一种叫{c}的老飞机的照片",
315
+ lambda c: f"一种叫{c}的老飞机的高清图",
316
+ lambda c: f"一种叫{c}的老飞机的模糊图片",
317
+ lambda c: f"一种叫{c}的老飞机的缩放图片",
318
+ lambda c: f"一种叫{c}的老飞机的摄影图",
319
+ lambda c: f"一种叫{c}的大飞机的照片",
320
+ lambda c: f"一种叫{c}的大飞机的高清图",
321
+ lambda c: f"一种叫{c}的大飞机的模糊图片",
322
+ lambda c: f"一种叫{c}的大飞机的缩放图片",
323
+ lambda c: f"一种叫{c}的大飞机的摄影图",
324
+ lambda c: f"一种叫{c}的小飞机的照片",
325
+ lambda c: f"一种叫{c}的小飞机的高清图",
326
+ lambda c: f"一种叫{c}的小飞机的模糊图片",
327
+ lambda c: f"一种叫{c}的小飞机的缩放图片",
328
+ lambda c: f"一种叫{c}的小飞机的摄影图",
329
+ lambda c: f"一种叫{c}的军用飞机的照片",
330
+ lambda c: f"一种叫{c}的军用飞机的高清图",
331
+ lambda c: f"一种叫{c}的军用飞机的模糊图片",
332
+ lambda c: f"一种叫{c}的军用飞机的缩放图片",
333
+ lambda c: f"一种叫{c}的军用飞机的摄影图",
334
+ lambda c: f"一种叫{c}的运输机的照片",
335
+ lambda c: f"一种叫{c}的运输机的高清图",
336
+ lambda c: f"一种叫{c}的运输机的模糊图片",
337
+ lambda c: f"一种叫{c}的运输机的缩放图片",
338
+ lambda c: f"一种叫{c}的运输机的摄影图",
339
+ lambda c: f"一种叫{c}的公务机的照片",
340
+ lambda c: f"一种叫{c}的公务机的高清图",
341
+ lambda c: f"一种叫{c}的公务机的模糊图片",
342
+ lambda c: f"一种叫{c}的公务机的缩放图片",
343
+ lambda c: f"一种叫{c}的公务机的摄影图",
344
+ lambda c: f"一种叫{c}的客机的照片",
345
+ lambda c: f"一种叫{c}的客机的高清图",
346
+ lambda c: f"一种叫{c}的客机的模糊图片",
347
+ lambda c: f"一种叫{c}的客机的缩放图片",
348
+ lambda c: f"一种叫{c}的客机的摄影图",
349
+ lambda c: f"一种叫{c}的喷气机的照片",
350
+ lambda c: f"一种叫{c}的喷气机的高清图",
351
+ lambda c: f"一种叫{c}的喷气机的模糊图片",
352
+ lambda c: f"一种叫{c}的喷气机的缩放图片",
353
+ lambda c: f"一种叫{c}的喷气机的摄影图",
354
+ ]
355
+
356
+ eurosat_templates = [
357
+ lambda c: f"一张{c}的卫星照片",
358
+ lambda c: f"{c}的卫星照片",
359
+ lambda c: f"一张{c}的高清卫星照片",
360
+ lambda c: f"{c}的高清卫星照片",
361
+ lambda c: f"一张{c}的清晰的卫星照片",
362
+ lambda c: f"{c}的清晰的卫星照片",
363
+ lambda c: f"一张{c}的高质量的卫星照片",
364
+ lambda c: f"{c}的高质量的卫星照片",
365
+ lambda c: f"一张{c}的卫星图",
366
+ lambda c: f"{c}的卫星图",
367
+ lambda c: f"一张{c}的高清卫星图",
368
+ lambda c: f"{c}的高清卫星图",
369
+ lambda c: f"一张{c}的清晰的卫星图",
370
+ lambda c: f"{c}的清晰的卫星图",
371
+ lambda c: f"一张{c}的高质量的卫星图",
372
+ lambda c: f"{c}的高质量的卫星图",
373
+ lambda c: f"一张{c}的卫星图片",
374
+ lambda c: f"{c}的卫星图片",
375
+ lambda c: f"一张{c}的高清卫星图片",
376
+ lambda c: f"{c}的高清卫星图片",
377
+ lambda c: f"一张{c}的清晰的卫星图片",
378
+ lambda c: f"{c}的清晰的卫星图片",
379
+ lambda c: f"一张{c}的高质量的卫星图片",
380
+ lambda c: f"{c}的高质量的卫星图片",
381
+ ]
382
+
383
+ hatefulmemes_templates = [
384
+ lambda c: f"一个{c}",
385
+ lambda c: f"{c}",
386
+ ]
387
+
388
+ kitti_templates = [
389
+ lambda c: f"照片里{c}",
390
+ lambda c: f"图片里{c}",
391
+ lambda c: f"{c}",
392
+ ]
393
+
394
+ cars_templates = [
395
+ lambda c: f"一张{c}的照片",
396
+ lambda c: f"一张我的{c}的照片",
397
+ lambda c: f"我爱我的{c}",
398
+ lambda c: f"一张我肮脏的{c}的照片",
399
+ lambda c: f"一张我干净的{c}的照片",
400
+ lambda c: f"一张我新买的{c}的照片",
401
+ lambda c: f"一张我旧的{c}的照片",
402
+ ]
403
+
404
+ dtd_templates = [
405
+ lambda c: f"一张{c}纹理的照片",
406
+ lambda c: f"一张{c}图案的照片",
407
+ lambda c: f"一张{c}物体的照片",
408
+ lambda c: f"一张{c}纹理的图片",
409
+ lambda c: f"一张{c}图案的图片",
410
+ lambda c: f"一张{c}物体的图片",
411
+ ]
412
+
413
+ country211_templates = [
414
+ lambda c: f"一张在{c}拍的照片",
415
+ lambda c: f"一张在{c}旅行时拍的照片",
416
+ lambda c: f"一张我家乡{c}的照片",
417
+ lambda c: f"一张展示{c}风光的照片",
418
+ ]
419
+
420
+ patch_templates = [
421
+ lambda c: f"一张{c}的医疗照片",
422
+ lambda c: f"一张{c}的ct照片",
423
+ lambda c: f"一张{c}的化验照片",
424
+ ]
425
+
426
+ pet_templates = [
427
+ lambda c: f"一种叫{c}的宠物的照片",
428
+ lambda c: f"一种叫{c}的宠物的图片",
429
+ lambda c: f"一种叫{c}的宠物的可爱图片",
430
+ lambda c: f"一种叫{c}的宠物的高清图片",
431
+ lambda c: f"一种叫{c}的宠物的模糊图片",
432
+ lambda c: f"一种叫{c}的宠物的特写照片",
433
+ ]
434
+
435
+ cifar100_templates = [
436
+ lambda c: f"一张{c}的照片",
437
+ lambda c: f"一张{c}的模糊照片",
438
+ lambda c: f"一张{c}",
439
+ lambda c: f"一张{c}的低对比度照片",
440
+ lambda c: f"一张{c}的高对比度照片",
441
+ lambda c: f"一张{c}的好照片",
442
+ lambda c: f"一张小{c}的照片",
443
+ lambda c: f"一张大{c}的照片",
444
+ lambda c: f"一张{c}的黑白照片",
445
+ lambda c: f"一张{c}的低对比度的照片",
446
+ lambda c: f"一张{c}的高对比度的照片",
447
+ ]
448
+
449
+ caltech101_templates = [
450
+ lambda c: f"{c}的照片",
451
+ lambda c: f"{c}的绘画",
452
+ lambda c: f"{c}的塑料",
453
+ lambda c: f"{c}的雕像",
454
+ lambda c: f"{c}的草图",
455
+ lambda c: f"{c}的刺青",
456
+ lambda c: f"{c}的玩具",
457
+ lambda c: f"{c}的演绎",
458
+ lambda c: f"{c}的装饰",
459
+ lambda c: f"{c}的卡通画",
460
+ lambda c: f"{c}在游戏中",
461
+ lambda c: f"一个豪华的{c}.",
462
+ lambda c: f"{c}的折纸",
463
+ lambda c: f"{c}的艺术画",
464
+ lambda c: f"{c}的涂鸦画",
465
+ lambda c: f"{c}的画",
466
+ ]
467
+
468
+ fer_templates = [
469
+ lambda c: f"一张表情{c}的照片",
470
+ lambda c: f"一张表达{c}情绪的照片",
471
+ lambda c: f"一张看起来很{c}的照片",
472
+ lambda c: f"他的脸看起来{c}",
473
+ lambda c: f"他们看起来很{c}",
474
+ ]
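
The lists above are plain lambdas that format a class name into Chinese prompts, so materializing a prompt set is a one-liner. A tiny sketch; the class names `猫` and `玫瑰` are just examples, and the import assumes `eval` is importable as a package:

```python
# Expand a class name into its prompt set using the template lists defined above.
from eval.cvinw_zeroshot_templates import openai_templates, flower_templates

cat_prompts = [template("猫") for template in openai_templates]
print(len(cat_prompts), cat_prompts[:2])   # e.g. '猫的照片', '质量差的猫的照片'

# Dataset-specific template lists work the same way:
rose_prompts = [template("玫瑰") for template in flower_templates]
```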
eval/data.py ADDED
@@ -0,0 +1,164 @@
1
+ import os
2
+ import logging
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from PIL import Image
7
+ import base64
8
+ from io import BytesIO
9
+ import torch
10
+ import lmdb
11
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
12
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from torch.utils.data.sampler import SequentialSampler
15
+ import torchvision.datasets as datasets
16
+ from clip import tokenize
17
+
18
+
19
+ def _convert_to_rgb(image):
20
+ return image.convert('RGB')
21
+
22
+
23
+ def _preprocess_text(text):
24
+ # adapt the text to Chinese BERT vocab
25
+ text = text.lower().replace("“", "\"").replace("”", "\"")
26
+ return text
27
+
28
+
29
+ class EvalTxtDataset(Dataset):
30
+ def __init__(self, jsonl_filename, max_txt_length=24):
31
+ assert os.path.exists(jsonl_filename), "The annotation datafile {} does not exist!".format(jsonl_filename)
32
+
33
+ logging.debug(f'Loading jsonl data from {jsonl_filename}.')
34
+ self.texts = []
35
+ with open(jsonl_filename, "r", encoding="utf-8") as fin:
36
+ for line in fin:
37
+ obj = json.loads(line.strip())
38
+ text_id = obj['text_id']
39
+ text = obj['text']
40
+ self.texts.append((text_id, text))
41
+ logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')
42
+
43
+ self.max_txt_length = max_txt_length
44
+
45
+ def __len__(self):
46
+ return len(self.texts)
47
+
48
+ def __getitem__(self, idx):
49
+ text_id, text = self.texts[idx]
50
+ text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
51
+ return text_id, text
52
+
53
+
54
+ class EvalImgDataset(Dataset):
55
+ def __init__(self, lmdb_imgs, resolution=224):
56
+ assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} does not exist!".format(lmdb_imgs)
57
+
58
+ logging.debug(f'Loading image LMDB from {lmdb_imgs}.')
59
+
60
+ self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
61
+ self.txn_imgs = self.env_imgs.begin(buffers=True)
62
+ self.cursor_imgs = self.txn_imgs.cursor()
63
+ self.iter_imgs = iter(self.cursor_imgs)
64
+ self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
65
+ logging.info("The specified LMDB directory contains {} images.".format(self.number_images))
66
+
67
+ self.transform = self._build_transform(resolution)
68
+
69
+ def _build_transform(self, resolution):
70
+ normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
71
+ return Compose([
72
+ Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
73
+ _convert_to_rgb,
74
+ ToTensor(),
75
+ normalize,
76
+ ])
77
+
78
+ def __len__(self):
79
+ return self.number_images
80
+
81
+ def __getitem__(self, idx):
82
+ img_id, image_b64 = next(self.iter_imgs)
83
+ if img_id == b"num_images":
84
+ img_id, image_b64 = next(self.iter_imgs)
85
+
86
+ img_id = img_id.tobytes()
87
+ image_b64 = image_b64.tobytes()
88
+
89
+ img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
90
+ image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
91
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64))) # already resized
92
+ image = self.transform(image)
93
+
94
+ return img_id, image
95
+
96
+
97
+ @dataclass
98
+ class DataInfo:
99
+ dataloader: DataLoader
100
+ sampler: DistributedSampler
101
+
102
+
103
+ def get_eval_txt_dataset(args, max_txt_length=24):
104
+ input_filename = args.text_data
105
+ dataset = EvalTxtDataset(
106
+ input_filename,
107
+ max_txt_length=max_txt_length)
108
+ num_samples = len(dataset)
109
+ sampler = SequentialSampler(dataset)
110
+
111
+ dataloader = DataLoader(
112
+ dataset,
113
+ batch_size=args.text_batch_size,
114
+ num_workers=0,
115
+ pin_memory=True,
116
+ sampler=sampler,
117
+ drop_last=False,
118
+ )
119
+ dataloader.num_samples = num_samples
120
+ dataloader.num_batches = len(dataloader)
121
+
122
+ return DataInfo(dataloader, sampler)
123
+
124
+
125
+ def fetch_resolution(vision_model):
126
+ # fetch the resolution from the vision model config
127
+ vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
128
+ with open(vision_model_config_file, 'r') as fv:
129
+ model_info = json.load(fv)
130
+ return model_info["image_resolution"]
131
+
132
+
133
+ def get_eval_img_dataset(args):
134
+ lmdb_imgs = args.image_data
135
+ dataset = EvalImgDataset(
136
+ lmdb_imgs, resolution=fetch_resolution(args.vision_model))
137
+ num_samples = len(dataset)
138
+ sampler = SequentialSampler(dataset)
139
+
140
+ dataloader = DataLoader(
141
+ dataset,
142
+ batch_size=args.img_batch_size,
143
+ num_workers=0,
144
+ pin_memory=True,
145
+ sampler=sampler,
146
+ drop_last=False,
147
+ )
148
+ dataloader.num_samples = num_samples
149
+ dataloader.num_batches = len(dataloader)
150
+
151
+ return DataInfo(dataloader, sampler)
152
+
153
+
154
+ def get_zeroshot_dataset(args, preprocess_fn):
155
+ dataset = datasets.ImageFolder(args.datapath, transform=preprocess_fn)
156
+
157
+ dataloader = torch.utils.data.DataLoader(
158
+ dataset,
159
+ batch_size=args.img_batch_size,
160
+ num_workers=args.num_workers,
161
+ sampler=None,
162
+ )
163
+
164
+ return DataInfo(dataloader, None)
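
A minimal sketch of driving `get_eval_txt_dataset` above outside the full evaluation pipeline; `SimpleNamespace` stands in for the argparse namespace the helpers expect, and the jsonl path is a placeholder:

```python
# Build the evaluation text dataloader and inspect one batch.
from types import SimpleNamespace
from eval.data import get_eval_txt_dataset

args = SimpleNamespace(text_data="test_texts.jsonl", text_batch_size=64)
text_info = get_eval_txt_dataset(args, max_txt_length=52)

for text_ids, texts in text_info.dataloader:
    print(len(text_ids), texts.shape)  # texts: [batch_size, context_length]
    break
```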
eval/evaluation.py ADDED
@@ -0,0 +1,157 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script computes the recall scores given the ground-truth annotations and predictions.
4
+ '''
5
+
6
+ import json
7
+ import sys
8
+ import os
9
+ import string
10
+ import numpy as np
11
+ import time
12
+
13
+ NUM_K = 10
14
+
15
+ def read_submission(submit_path, reference, k=5):
16
+ # check whether the path of submitted file exists
17
+ if not os.path.exists(submit_path):
18
+ raise Exception("The submission file is not found!")
19
+
20
+ submission_dict = {}
21
+ ref_qids = set(reference.keys())
22
+
23
+ with open(submit_path, encoding="utf-8") as fin:
24
+ for line in fin:
25
+ line = line.strip()
26
+ try:
27
+ pred_obj = json.loads(line)
28
+ except:
29
+ raise Exception('Cannot parse this line into json object: {}'.format(line))
30
+ if "text_id" not in pred_obj:
31
+ raise Exception('There exists one line not containing text_id: {}'.format(line))
32
+ if not isinstance(pred_obj['text_id'], int):
33
+ raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['text_id']))
34
+ qid = pred_obj["text_id"]
35
+ if "image_ids" not in pred_obj:
36
+ raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
37
+ image_ids = pred_obj["image_ids"]
38
+ if not isinstance(image_ids, list):
39
+ raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
40
+ # check whether there are exactly K predicted image_ids for each text
41
+ if len(image_ids) != k:
42
+ raise Exception('Text_id {} has wrong number of predicted image_ids! Requires {}, but {} were found.'.format(qid, k, len(image_ids)))
43
+ # check whether there exists an invalid prediction for any text
44
+ for rank, image_id in enumerate(image_ids):
45
+ if not isinstance(image_id, int):
46
+ raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
47
+ # check whether there are duplicate predicted products for a single text
48
+ if len(set(image_ids)) != k:
49
+ raise Exception('Text_id {} has duplicate image_ids in your prediction. Please check again!'.format(qid))
50
+ submission_dict[qid] = image_ids # here we save the list of predicted image ids
51
+
52
+ # check if any text is missing in the submission
53
+ pred_qids = set(submission_dict.keys())
54
+ nopred_qids = ref_qids - pred_qids
55
+ if len(nopred_qids) != 0:
56
+ raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))
57
+
58
+ return submission_dict
59
+
60
+
61
+ def dump_2_json(info, path):
62
+ with open(path, 'w', encoding="utf-8") as output_json_file:
63
+ json.dump(info, output_json_file)
64
+
65
+
66
+ def report_error_msg(detail, showMsg, out_p):
67
+ error_dict=dict()
68
+ error_dict['errorDetail']=detail
69
+ error_dict['errorMsg']=showMsg
70
+ error_dict['score']=0
71
+ error_dict['scoreJson']={}
72
+ error_dict['success']=False
73
+ dump_2_json(error_dict,out_p)
74
+
75
+
76
+ def report_score(r1, r5, r10, out_p):
77
+ result = dict()
78
+ result['success']=True
79
+ mean_recall = (r1 + r5 + r10) / 3.0
80
+ result['score'] = mean_recall * 100
81
+ result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
82
+ dump_2_json(result,out_p)
83
+
84
+
85
+ def read_reference(path):
86
+ fin = open(path, encoding="utf-8")
87
+ reference = dict()
88
+ for line in fin:
89
+ line = line.strip()
90
+ obj = json.loads(line)
91
+ reference[obj['text_id']] = obj['image_ids']
92
+ return reference
93
+
94
+ def compute_score(golden_file, predict_file):
95
+ # read ground-truth
96
+ reference = read_reference(golden_file)
97
+
98
+ # read predictions
99
+ k = 10
100
+ predictions = read_submission(predict_file, reference, k)
101
+
102
+ # compute score for each text
103
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
104
+ for qid in reference.keys():
105
+ ground_truth_ids = set(reference[qid])
106
+ top10_pred_ids = predictions[qid]
107
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
108
+ r1_stat += 1
109
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
110
+ r5_stat += 1
111
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
112
+ r10_stat += 1
113
+ # the higher score, the better
114
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
115
+ mean_recall = (r1 + r5 + r10) / 3.0
116
+ result = [mean_recall, r1, r5, r10]
117
+ result = [score * 100 for score in result]
118
+ return result
119
+
120
+
121
+ if __name__=="__main__":
122
+ # the path of answer json file (eg. test_queries_answers.jsonl)
123
+ standard_path = sys.argv[1]
124
+ # the path of prediction file (eg. example_pred.jsonl)
125
+ submit_path = sys.argv[2]
126
+ # the score will be dumped into this output json file
127
+ out_path = sys.argv[3]
128
+
129
+ print("Read standard from %s" % standard_path)
130
+ print("Read user submit file from %s" % submit_path)
131
+
132
+ try:
133
+ # read ground-truth
134
+ reference = read_reference(standard_path)
135
+
136
+ # read predictions
137
+ k = 10
138
+ predictions = read_submission(submit_path, reference, k)
139
+
140
+ # compute score for each text
141
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
142
+ for qid in reference.keys():
143
+ ground_truth_ids = set(reference[qid])
144
+ top10_pred_ids = predictions[qid]
145
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
146
+ r1_stat += 1
147
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
148
+ r5_stat += 1
149
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
150
+ r10_stat += 1
151
+ # the higher score, the better
152
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
153
+ report_score(r1, r5, r10, out_path)
154
+ print("The evaluation finished successfully.")
155
+ except Exception as e:
156
+ report_error_msg(e.args[0], e.args[0], out_path)
157
+ print("The evaluation failed: {}".format(e.args[0]))
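
For clarity, a sketch of the jsonl schema `compute_score` above expects, plus a programmatic call. The file names follow the examples mentioned in the comments and are placeholders; the import assumes `eval` is importable as a package:

```python
# Ground truth, one JSON object per line: {"text_id": 1, "image_ids": [10, 11]}
# Predictions, one JSON object per line:  {"text_id": 1, "image_ids": [10, 3, ...]}  # exactly k=10 ids
from eval.evaluation import compute_score

mean_recall, r1, r5, r10 = compute_score("test_queries_answers.jsonl", "example_pred.jsonl")
print(f"mean recall {mean_recall:.2f} | R@1 {r1:.2f} | R@5 {r5:.2f} | R@10 {r10:.2f}")
```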
eval/evaluation_tr.py ADDED
@@ -0,0 +1,157 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script computes the recall scores given the ground-truth annotations and predictions.
4
+ '''
5
+
6
+ import json
7
+ import sys
8
+ import os
9
+ import string
10
+ import numpy as np
11
+ import time
12
+
13
+ NUM_K = 10
14
+
15
+ def read_submission(submit_path, reference, k=5):
16
+ # check whether the path of submitted file exists
17
+ if not os.path.exists(submit_path):
18
+ raise Exception("The submission file is not found!")
19
+
20
+ submission_dict = {}
21
+ ref_image_ids = set(reference.keys())
22
+
23
+ with open(submit_path, encoding="utf-8") as fin:
24
+ for line in fin:
25
+ line = line.strip()
26
+ try:
27
+ pred_obj = json.loads(line)
28
+ except:
29
+ raise Exception('Cannot parse this line into json object: {}'.format(line))
30
+ if "image_id" not in pred_obj:
31
+ raise Exception('There exists one line not containing image_id: {}'.format(line))
32
+ if not isinstance(pred_obj['image_id'], int):
33
+ raise Exception('Found an invalid image_id {}, it should be an integer (not string), please check your schema'.format(pred_obj['image_id']))
34
+ image_id = pred_obj['image_id']
35
+ if "text_ids" not in pred_obj:
36
+ raise Exception('There exists one line not containing the predicted text_ids: {}'.format(line))
37
+ text_ids = pred_obj["text_ids"]
38
+ if not isinstance(text_ids, list):
39
+ raise Exception('The text_ids field of image_id {} is not a list, please check your schema'.format(image_id))
40
+ # check whether there are exactly K predicted text_ids for each image
41
+ if len(text_ids) != k:
42
+ raise Exception('Image_id {} has wrong number of predicted text_ids! Requires {}, but {} were found.'.format(image_id, k, len(text_ids)))
43
+ # check whether there exists an invalid prediction for any image
44
+ for rank, text_id in enumerate(text_ids):
45
+ if not isinstance(text_id, int):
46
+ raise Exception('Image_id {} has an invalid predicted text_id {} at rank {}, it should be an integer (not string), please check your schema'.format(image_id, text_id, rank + 1))
47
+ # check whether there are duplicate predicted products for a single text
48
+ if len(set(text_ids)) != k:
49
+ raise Exception('Image_id {} has duplicate text_ids in your prediction. Please check again!'.format(image_id))
50
+ submission_dict[image_id] = text_ids # here we save the list of predicted text ids
51
+
52
+ # check if any image is missing in the submission
53
+ pred_image_ids = set(submission_dict.keys())
54
+ nopred_image_ids = ref_image_ids - pred_image_ids
55
+ if len(nopred_image_ids) != 0:
56
+ raise Exception('The following image_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_image_ids])))
57
+
58
+ return submission_dict
59
+
60
+
61
+ def dump_2_json(info, path):
62
+ with open(path, 'w', encoding="utf-8") as output_json_file:
63
+ json.dump(info, output_json_file)
64
+
65
+
66
+ def report_error_msg(detail, showMsg, out_p):
67
+ error_dict=dict()
68
+ error_dict['errorDetail']=detail
69
+ error_dict['errorMsg']=showMsg
70
+ error_dict['score']=0
71
+ error_dict['scoreJson']={}
72
+ error_dict['success']=False
73
+ dump_2_json(error_dict,out_p)
74
+
75
+
76
+ def report_score(r1, r5, r10, out_p):
77
+ result = dict()
78
+ result['success']=True
79
+ mean_recall = (r1 + r5 + r10) / 3.0
80
+ result['score'] = mean_recall * 100
81
+ result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100, 'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
82
+ dump_2_json(result,out_p)
83
+
84
+
85
+ def read_reference(path):
86
+ fin = open(path, encoding="utf-8")
87
+ reference = dict()
88
+ for line in fin:
89
+ line = line.strip()
90
+ obj = json.loads(line)
91
+ reference[obj['image_id']] = obj['text_ids']
92
+ return reference
93
+
94
+ def compute_score(golden_file, predict_file):
95
+ # read ground-truth
96
+ reference = read_reference(golden_file)
97
+
98
+ # read predictions
99
+ k = 10
100
+ predictions = read_submission(predict_file, reference, k)
101
+
102
+ # compute score for each text
103
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
104
+ for qid in reference.keys():
105
+ ground_truth_ids = set(reference[qid])
106
+ top10_pred_ids = predictions[qid]
107
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
108
+ r1_stat += 1
109
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
110
+ r5_stat += 1
111
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
112
+ r10_stat += 1
113
+ # the higher score, the better
114
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
115
+ mean_recall = (r1 + r5 + r10) / 3.0
116
+ result = [mean_recall, r1, r5, r10]
117
+ result = [score * 100 for score in result]
118
+ return result
119
+
120
+
121
+ if __name__=="__main__":
122
+ # the path of answer json file (eg. test_queries_answers.jsonl)
123
+ standard_path = sys.argv[1]
124
+ # the path of prediction file (eg. example_pred.jsonl)
125
+ submit_path = sys.argv[2]
126
+ # the score will be dumped into this output json file
127
+ out_path = sys.argv[3]
128
+
129
+ print("Read standard from %s" % standard_path)
130
+ print("Read user submit file from %s" % submit_path)
131
+
132
+ try:
133
+ # read ground-truth
134
+ reference = read_reference(standard_path)
135
+
136
+ # read predictions
137
+ k = 10
138
+ predictions = read_submission(submit_path, reference, k)
139
+
140
+ # compute score for each text
141
+ r1_stat, r5_stat, r10_stat = 0, 0, 0
142
+ for qid in reference.keys():
143
+ ground_truth_ids = set(reference[qid])
144
+ top10_pred_ids = predictions[qid]
145
+ if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
146
+ r1_stat += 1
147
+ if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
148
+ r5_stat += 1
149
+ if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
150
+ r10_stat += 1
151
+ # the higher score, the better
152
+ r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
153
+ report_score(r1, r5, r10, out_path)
154
+ print("The evaluation finished successfully.")
155
+ except Exception as e:
156
+ report_error_msg(e.args[0], e.args[0], out_path)
157
+ print("The evaluation failed: {}".format(e.args[0]))
eval/extract_features.py ADDED
@@ -0,0 +1,212 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script extracts image and text features for evaluation (single-GPU).
4
+ '''
5
+
6
+ import os
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+ import json
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ from clip.model import convert_weights, CLIP
16
+ from eval.data import get_eval_img_dataset, get_eval_txt_dataset
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument(
21
+ '--extract-image-feats',
22
+ action="store_true",
23
+ default=False,
24
+ help="Whether to extract image features."
25
+ )
26
+ parser.add_argument(
27
+ '--extract-text-feats',
28
+ action="store_true",
29
+ default=False,
30
+ help="Whether to extract text features."
31
+ )
32
+ parser.add_argument(
33
+ '--image-data',
34
+ type=str,
35
+ default="../Multimodal_Retrieval/lmdb/test/imgs",
36
+ help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
37
+ )
38
+ parser.add_argument(
39
+ '--text-data',
40
+ type=str,
41
+ default="../Multimodal_Retrieval/test_texts.jsonl",
42
+ help="If --extract-text-feats is True, specify the path of input text Jsonl file."
43
+ )
44
+ parser.add_argument(
45
+ '--image-feat-output-path',
46
+ type=str,
47
+ default=None,
48
+ help="If --extract-image-feats is True, specify the path of output image features."
49
+ )
50
+ parser.add_argument(
51
+ '--text-feat-output-path',
52
+ type=str,
53
+ default=None,
54
+ help="If --extract-text-feats is True, specify the path of output text features."
55
+ )
56
+ parser.add_argument(
57
+ "--img-batch-size", type=int, default=64, help="Image batch size."
58
+ )
59
+ parser.add_argument(
60
+ "--text-batch-size", type=int, default=64, help="Text batch size."
61
+ )
62
+ parser.add_argument(
63
+ "--context-length", type=int, default=64, help="The maximum length of input text (including [CLS] & [SEP] tokens)."
64
+ )
65
+ parser.add_argument(
66
+ "--resume",
67
+ default=None,
68
+ type=str,
69
+ help="path to latest checkpoint (default: none)",
70
+ )
71
+ parser.add_argument(
72
+ "--precision",
73
+ choices=["amp", "fp16", "fp32"],
74
+ default="amp",
75
+ help="Floating point precision."
76
+ )
77
+ parser.add_argument(
78
+ "--vision-model",
79
+ choices=["ViT-B-16", "ViT-L-14", "RN50"],
80
+ default="ViT-B-16",
81
+ help="Name of the vision backbone to use.",
82
+ )
83
+ parser.add_argument(
84
+ "--text-model",
85
+ choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
86
+ default="RoBERTa-wwm-ext-base-chinese",
87
+ help="Name of the text backbone to use.",
88
+ )
89
+ parser.add_argument(
90
+ "--debug",
91
+ default=False,
92
+ action="store_true",
93
+ help="If true, more information is logged."
94
+ )
95
+ args = parser.parse_args()
96
+
97
+ return args
98
+
99
+ # Used by https://github.com/openai/CLIP/issues/83 but not below.
100
+ # Keeping it in case it's needed.
101
+ def convert_models_to_fp32(model):
102
+ for p in model.parameters():
103
+ p.data = p.data.float()
104
+ if p.grad:
105
+ p.grad.data = p.grad.data.float()
106
+
107
+
108
+ if __name__ == "__main__":
109
+ args = parse_args()
110
+
111
+ assert args.extract_image_feats or args.extract_text_feats, "--extract-image-feats and --extract-text-feats cannot both be False!"
112
+
113
+ # Log params.
114
+ print("Params:")
115
+ for name in sorted(vars(args)):
116
+ val = getattr(args, name)
117
+ print(f" {name}: {val}")
118
+
119
+ args.gpu = 0
120
+ torch.cuda.set_device(args.gpu)
121
+
122
+ # Initialize the model.
123
+ vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
124
+ print('Loading vision model config from', vision_model_config_file)
125
+ assert os.path.exists(vision_model_config_file)
126
+
127
+ text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
128
+ print('Loading text model config from', text_model_config_file)
129
+ assert os.path.exists(text_model_config_file)
130
+
131
+ with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
132
+ model_info = json.load(fv)
133
+ if isinstance(model_info['vision_layers'], str):
134
+ model_info['vision_layers'] = eval(model_info['vision_layers'])
135
+ for k, v in json.load(ft).items():
136
+ model_info[k] = v
137
+
138
+ model = CLIP(**model_info)
139
+ convert_weights(model)
140
+
141
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
142
+ if args.precision == "amp" or args.precision == "fp32":
143
+ convert_models_to_fp32(model)
144
+ model.cuda(args.gpu)
145
+ if args.precision == "fp16":
146
+ convert_weights(model)
147
+
148
+ # Get data.
149
+ if args.extract_image_feats:
150
+ print("Preparing image inference dataset.")
151
+ img_data = get_eval_img_dataset(args)
152
+ if args.extract_text_feats:
153
+ print("Preparing text inference dataset.")
154
+ text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)
155
+
156
+ # Resume from a checkpoint.
157
+ print("Begin to load model checkpoint from {}.".format(args.resume))
158
+ assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
159
+ # Map model to be loaded to specified single gpu.
160
+ loc = "cuda:{}".format(args.gpu)
161
+ checkpoint = torch.load(args.resume, map_location='cpu')
162
+ start_epoch = checkpoint["epoch"]
163
+ sd = checkpoint["state_dict"]
164
+ if next(iter(sd.items()))[0].startswith('module'):
165
+ sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
166
+ model.load_state_dict(sd)
167
+ print(
168
+ f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
169
+ )
170
+
171
+ # Make inference for texts
172
+ if args.extract_text_feats:
173
+ print('Make inference for texts...')
174
+ if args.text_feat_output_path is None:
175
+ args.text_feat_output_path = "{}.txt_feat.jsonl".format(args.text_data[:-6])
176
+ write_cnt = 0
177
+ with open(args.text_feat_output_path, "w") as fout:
178
+ model.eval()
179
+ dataloader = text_data.dataloader
180
+ with torch.no_grad():
181
+ for batch in tqdm(dataloader):
182
+ text_ids, texts = batch
183
+ texts = texts.cuda(args.gpu, non_blocking=True)
184
+ text_features = model(None, texts)
185
+ text_features /= text_features.norm(dim=-1, keepdim=True)
186
+ for text_id, text_feature in zip(text_ids.tolist(), text_features.tolist()):
187
+ fout.write("{}\n".format(json.dumps({"text_id": text_id, "feature": text_feature})))
188
+ write_cnt += 1
189
+ print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))
190
+
191
+ # Make inference for images
192
+ if args.extract_image_feats:
193
+ print('Make inference for images...')
194
+ if args.image_feat_output_path is None:
195
+ # by default, we store the image features under the same directory with the text features
196
+ args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
197
+ write_cnt = 0
198
+ with open(args.image_feat_output_path, "w") as fout:
199
+ model.eval()
200
+ dataloader = img_data.dataloader
201
+ with torch.no_grad():
202
+ for batch in tqdm(dataloader):
203
+ image_ids, images = batch
204
+ images = images.cuda(args.gpu, non_blocking=True)
205
+ image_features = model(images, None)
206
+ image_features /= image_features.norm(dim=-1, keepdim=True)
207
+ for image_id, image_feature in zip(image_ids.tolist(), image_features.tolist()):
208
+ fout.write("{}\n".format(json.dumps({"image_id": image_id, "feature": image_feature})))
209
+ write_cnt += 1
210
+ print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))
211
+
212
+ print("Done!")
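
The feature files written above are plain jsonl, one record per item. A minimal sketch of reading one back into a tensor; the path is a placeholder:

```python
# Each line has the form {"image_id": 123, "feature": [0.01, -0.02, ...]}
# (text features use "text_id" instead); features are already L2-normalized.
import json
import torch

image_ids, feats = [], []
with open("test_imgs.img_feat.jsonl", "r") as fin:
    for line in fin:
        obj = json.loads(line)
        image_ids.append(obj["image_id"])
        feats.append(obj["feature"])

image_feats = torch.tensor(feats)        # [num_images, feature_dim]
print(len(image_ids), image_feats.shape)
```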
eval/make_topk_predictions.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs kNN search on the inferred image and text features (single-GPU) and outputs a text-to-image prediction file for evaluation.
4
+ '''
5
+
6
+ import argparse
7
+ import numpy
8
+ from tqdm import tqdm
9
+ import json
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--image-feats',
18
+ type=str,
19
+ required=True,
20
+ help="Specify the path of image features."
21
+ )
22
+ parser.add_argument(
23
+ '--text-feats',
24
+ type=str,
25
+ required=True,
26
+ help="Specify the path of text features."
27
+ )
28
+ parser.add_argument(
29
+ '--top-k',
30
+ type=int,
31
+ default=10,
32
+ help="Specify the k value of top-k predictions."
33
+ )
34
+ parser.add_argument(
35
+ '--eval-batch-size',
36
+ type=int,
37
+ default=32768,
38
+ help="Specify the image-side batch size when computing the inner products, defaults to 32768"
39
+ )
40
+ parser.add_argument(
41
+ '--output',
42
+ type=str,
43
+ required=True,
44
+ help="Specify the output jsonl prediction filepath."
45
+ )
46
+ return parser.parse_args()
47
+
48
+ if __name__ == "__main__":
49
+ args = parse_args()
50
+
51
+ # Log params.
52
+ print("Params:")
53
+ for name in sorted(vars(args)):
54
+ val = getattr(args, name)
55
+ print(f" {name}: {val}")
56
+
57
+ print("Begin to load image features...")
58
+ image_ids = []
59
+ image_feats = []
60
+ with open(args.image_feats, "r") as fin:
61
+ for line in tqdm(fin):
62
+ obj = json.loads(line.strip())
63
+ image_ids.append(obj['image_id'])
64
+ image_feats.append(obj['feature'])
65
+ image_feats_array = np.array(image_feats, dtype=np.float32)
66
+ print("Finished loading image features.")
67
+
68
+ print("Begin to compute top-{} predictions for texts...".format(args.top_k))
69
+ with open(args.output, "w") as fout:
70
+ with open(args.text_feats, "r") as fin:
71
+ for line in tqdm(fin):
72
+ obj = json.loads(line.strip())
73
+ text_id = obj['text_id']
74
+ text_feat = obj['feature']
75
+ score_tuples = []
76
+ text_feat_tensor = torch.tensor([text_feat], dtype=torch.float).cuda() # [1, feature_dim]
77
+ idx = 0
78
+ while idx < len(image_ids):
79
+ img_feats_tensor = torch.from_numpy(image_feats_array[idx : min(idx + args.eval_batch_size, len(image_ids))]).cuda() # [batch_size, feature_dim]
80
+ batch_scores = text_feat_tensor @ img_feats_tensor.t() # [1, batch_size]
81
+ for image_id, score in zip(image_ids[idx : min(idx + args.eval_batch_size, len(image_ids))], batch_scores.squeeze(0).tolist()):
82
+ score_tuples.append((image_id, score))
83
+ idx += args.eval_batch_size
84
+ top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
85
+ fout.write("{}\n".format(json.dumps({"text_id": text_id, "image_ids": [entry[0] for entry in top_k_predictions]})))
86
+
87
+ print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
88
+ print("Done!")
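
When all features fit in GPU or CPU memory, the batched loop above reduces to one matrix multiply followed by `torch.topk`; a self-contained sketch with random, purely illustrative shapes:

```python
# With L2-normalized features the inner product is cosine similarity,
# so text-to-image retrieval is a matmul plus a top-k over the image axis.
import torch

text_feats = torch.randn(100, 512)
image_feats = torch.randn(5000, 512)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)

scores = text_feats @ image_feats.t()                 # [num_texts, num_images]
topk_scores, topk_indices = scores.topk(10, dim=-1)   # best 10 image indices per text
print(topk_indices.shape)                             # torch.Size([100, 10])
```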
eval/make_topk_predictions_tr.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs kNN search on the inferred image and text features (single-GPU) and outputs an image-to-text retrieval prediction file for evaluation.
4
+ '''
5
+
6
+ import argparse
7
+ import numpy
8
+ from tqdm import tqdm
9
+ import json
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--image-feats',
18
+ type=str,
19
+ required=True,
20
+ help="Specify the path of image features."
21
+ )
22
+ parser.add_argument(
23
+ '--text-feats',
24
+ type=str,
25
+ required=True,
26
+ help="Specify the path of text features."
27
+ )
28
+ parser.add_argument(
29
+ '--top-k',
30
+ type=int,
31
+ default=10,
32
+ help="Specify the k value of top-k predictions."
33
+ )
34
+ parser.add_argument(
35
+ '--eval-batch-size',
36
+ type=int,
37
+ default=32768,
38
+ help="Specify the text-side batch size when computing the inner products, defaults to 32768"
39
+ )
40
+ parser.add_argument(
41
+ '--output',
42
+ type=str,
43
+ required=True,
44
+ help="Specify the output jsonl prediction filepath."
45
+ )
46
+ return parser.parse_args()
47
+
48
+ if __name__ == "__main__":
49
+ args = parse_args()
50
+
51
+ # Log params.
52
+ print("Params:")
53
+ for name in sorted(vars(args)):
54
+ val = getattr(args, name)
55
+ print(f" {name}: {val}")
56
+
57
+ print("Begin to load text features...")
58
+ text_ids = []
59
+ text_feats = []
60
+ with open(args.text_feats, "r") as fin:
61
+ for line in tqdm(fin):
62
+ obj = json.loads(line.strip())
63
+ text_ids.append(obj['text_id'])
64
+ text_feats.append(obj['feature'])
65
+ text_feats_array = np.array(text_feats, dtype=np.float32)
66
+ print("Finished loading text features.")
67
+
68
+ print("Begin to compute top-{} predictions for images...".format(args.top_k))
69
+ with open(args.output, "w") as fout:
70
+ with open(args.image_feats, "r") as fin:
71
+ for line in tqdm(fin):
72
+ obj = json.loads(line.strip())
73
+ image_id = obj['image_id']
74
+ image_feat = obj['feature']
75
+ score_tuples = []
76
+ image_feat_tensor = torch.tensor([image_feat], dtype=torch.float).cuda() # [1, feature_dim]
77
+ idx = 0
78
+ while idx < len(text_ids):
79
+ text_feats_tensor = torch.from_numpy(text_feats_array[idx : min(idx + args.eval_batch_size, len(text_ids))]).cuda() # [batch_size, feature_dim]
80
+ batch_scores = image_feat_tensor @ text_feats_tensor.t() # [1, batch_size]
81
+ for text_id, score in zip(text_ids[idx : min(idx + args.eval_batch_size, len(text_ids))], batch_scores.squeeze(0).tolist()):
82
+ score_tuples.append((text_id, score))
83
+ idx += args.eval_batch_size
84
+ top_k_predictions = sorted(score_tuples, key=lambda x:x[1], reverse=True)[:args.top_k]
85
+ fout.write("{}\n".format(json.dumps({"image_id": image_id, "text_ids": [entry[0] for entry in top_k_predictions]})))
86
+
87
+ print("Top-{} predictions are saved in {}".format(args.top_k, args.output))
88
+ print("Done!")
eval/transform_ir_annotation_to_tr.py ADDED
@@ -0,0 +1,36 @@
1
+ # -*- coding: utf-8 -*-
2
+ from tqdm import tqdm
3
+ import argparse
4
+ import json
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument(
9
+ '--input',
10
+ type=str,
11
+ required=True,
12
+ help="Input path of text-to-image Jsonl annotation file."
13
+ )
14
+ return parser.parse_args()
15
+
16
+ if __name__ == "__main__":
17
+ args = parse_args()
18
+
19
+ t2i_record = dict()
20
+
21
+ with open(args.input, "r", encoding="utf-8") as fin:
22
+ for line in tqdm(fin):
23
+ obj = json.loads(line.strip())
24
+ text_id = obj['text_id']
25
+ image_ids = obj['image_ids']
26
+ for image_id in image_ids:
27
+ if image_id not in t2i_record:
28
+ t2i_record[image_id] = []
29
+ t2i_record[image_id].append(text_id)
30
+
31
+ with open(args.input.replace(".jsonl", "") + ".tr.jsonl", "w", encoding="utf-8") as fout:
32
+ for image_id, text_ids in t2i_record.items():
33
+ out_obj = {"image_id": image_id, "text_ids": text_ids}
34
+ fout.write("{}\n".format(json.dumps(out_obj)))
35
+
36
+ print("Done!")
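
A tiny worked example of the inversion the script above performs: text-to-image annotations become image-to-text annotations by flipping the direction of the mapping.

```python
t2i = [
    {"text_id": 1, "image_ids": [10, 11]},
    {"text_id": 2, "image_ids": [11]},
]

i2t = {}
for obj in t2i:
    for image_id in obj["image_ids"]:
        i2t.setdefault(image_id, []).append(obj["text_id"])

print(i2t)  # {10: [1], 11: [1, 2]}
```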
eval/zeroshot_evaluation.py ADDED
@@ -0,0 +1,267 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ This script performs zero-shot evaluation on ImageNet-1K (single-GPU).
4
+ '''
5
+
6
+ import os
7
+ import argparse
8
+ from pathlib import Path
9
+ import json
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+
14
+ from clip.model import convert_weights, CLIP
15
+ from clip import tokenize
16
+ from clip.utils import image_transform
17
+ from eval.data import get_zeroshot_dataset, _preprocess_text
18
+ from eval.cvinw_zeroshot_templates import (
19
+ openai_templates,
20
+ flower_templates,
21
+ food_templates,
22
+ aircraft_templates,
23
+ eurosat_templates,
24
+ country211_templates,
25
+ )
26
+
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument(
31
+ "--vision-model",
32
+ choices=["ViT-B-16", "ViT-L-14", "RN50"],
33
+ default="ViT-B-16",
34
+ help="Name of the vision backbone to use.",
35
+ )
36
+ parser.add_argument(
37
+ "--text-model",
38
+ choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
39
+ default="RoBERTa-wwm-ext-base-chinese",
40
+ help="Name of the text backbone to use.",
41
+ )
42
+ parser.add_argument(
43
+ "--precision",
44
+ choices=["amp", "fp16", "fp32"],
45
+ default="amp",
46
+ help="Floating point precision."
47
+ )
48
+ parser.add_argument(
49
+ "--label-file",
50
+ type=str,
51
+ help="file for labels",
52
+ )
53
+ parser.add_argument(
54
+ "--datapath",
55
+ type=str,
56
+ required=True,
57
+ help="Path to the test set for conducting zero shot evaluation.",
58
+ )
59
+ parser.add_argument(
60
+ "--dataset",
61
+ type=str,
62
+ default="imagenet",
63
+ help="Name of the dataset to evaluate.",
64
+ )
65
+ parser.add_argument(
66
+ "--index",
67
+ type=str,
68
+ default="",
69
+ help="Path to a JSON index file used to reorder the output logits.",
70
+ )
71
+ parser.add_argument(
72
+ "--save-dir",
73
+ type=str,
74
+ default="",
75
+ help="Directory for saving evaluation outputs.",
76
+ )
77
+ # parser.add_argument(
78
+ # "--imagenet-val",
79
+ # type=str,
80
+ # required=True,
81
+ # help="Path to imagenet val set for conducting zero shot evaluation.",
82
+ # )
83
+ parser.add_argument(
84
+ "--img-batch-size", type=int, default=64, help="Image batch size."
85
+ )
86
+ parser.add_argument(
87
+ "--context-length",
88
+ type=int,
89
+ default=52,
90
+ help="The maximum length of input text (including [CLS] & [SEP] tokens)."
91
+ )
92
+ parser.add_argument(
93
+ "--resume",
94
+ default=None,
95
+ type=str,
96
+ help="path to latest checkpoint (default: none)",
97
+ )
98
+ parser.add_argument(
99
+ "--num-workers", type=int, default=4, help="Number of workers for ImageNet dataloader."
100
+ )
101
+ args = parser.parse_args()
102
+
103
+ return args
104
+
105
+ # Adapted from https://github.com/openai/CLIP/issues/83; used below when evaluating in amp/fp32.
+ def convert_models_to_fp32(model):
+     for p in model.parameters():
+         p.data = p.data.float()
+         if p.grad is not None:
+             p.grad.data = p.grad.data.float()
+
+
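+ # Build the zero-shot classifier by prompt ensembling: each class name is expanded
+ # with every template, the resulting text embeddings are L2-normalized and averaged,
+ # and the mean is re-normalized, yielding one classifier weight column per class.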
+ def zero_shot_classifier(model, classnames, templates, args):
+     with torch.no_grad():
+         zeroshot_weights = []
+         for classname in tqdm(classnames):
+             texts = [_preprocess_text(template(classname)) for template in templates]  # format with class
+             texts = tokenize(texts, context_length=args.context_length).to(args.gpu)  # tokenize
+             class_embeddings = model(None, texts)
+             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+             class_embedding = class_embeddings.mean(dim=0)
+             class_embedding /= class_embedding.norm()
+             zeroshot_weights.append(class_embedding)
+         zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
+     return zeroshot_weights
+
+
+ def accuracy(output, target, topk=(1,)):
+     pred = output.topk(max(topk), 1, True, True)[1].t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+     return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
+
+
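+ # Encode each image batch, score it against the classifier matrix, and accumulate
+ # top-1 accuracy; the concatenated softmax logits are also returned alongside the
+ # accuracy (optionally re-ordered via --index).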
+ def run(model, classifier, dataloader, args):
+     total_logits = []
+     total_targets = []
+     with torch.no_grad():
+         top1, n = 0.0, 0.0
+         for images, target in tqdm(dataloader):
+             images = images.to(args.gpu)
+             target = target.to(args.gpu)
+             total_targets.append(target)
+
+             # predict
+             image_features = model(images, None)
+             image_features /= image_features.norm(dim=-1, keepdim=True)
+             logits = (100.0 * image_features @ classifier).softmax(dim=-1)
+             total_logits.append(logits)
+
+             # measure top-1 accuracy
+             (acc1,) = accuracy(logits, target, topk=(1,))
+             top1 += acc1
+             n += images.size(0)
+
+     outputs = torch.cat(total_logits, dim=0)
+     targets = torch.cat(total_targets, dim=0)
+
+     if getattr(args, "index", ""):
+         print("Use index to rearrange the logits...")
+         with open(args.index, "r", encoding="utf-8") as f:
+             index = json.load(f)
+         print(index)
+         outputs = outputs[index]
+         targets = targets[index]
+         print(targets)
+
+     top1 = top1 / n
+
+     return top1, outputs
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     # Log params.
+     print("Params:")
+     for name in sorted(vars(args)):
+         val = getattr(args, name)
+         print(f"  {name}: {val}")
+
+     args.gpu = 0
+     torch.cuda.set_device(args.gpu)
+
+     # Initialize the model.
+     vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
+     print('Loading vision model config from', vision_model_config_file)
+     assert os.path.exists(vision_model_config_file)
+
+     text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
+     print('Loading text model config from', text_model_config_file)
+     assert os.path.exists(text_model_config_file)
+
+     with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
+         model_info = json.load(fv)
+         if isinstance(model_info['vision_layers'], str):
+             model_info['vision_layers'] = eval(model_info['vision_layers'])
+         for k, v in json.load(ft).items():
+             model_info[k] = v
+
+     model = CLIP(**model_info)
+     convert_weights(model)
+
+     # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
+     if args.precision == "amp" or args.precision == "fp32":
+         convert_models_to_fp32(model)
+     model.cuda(args.gpu)
+     if args.precision == "fp16":
+         convert_weights(model)
+
+     # Get eval data.
+     print("Preparing zeroshot dataset.")
+     data = {}
+     print(f"Image resolution: {model_info['image_resolution']}")
+     data[args.dataset] = get_zeroshot_dataset(
+         args, image_transform(model_info["image_resolution"])
+     )
+
+     # Resume from a checkpoint.
+     print("Begin to load model checkpoint from {}.".format(args.resume))
+     assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)
+     # Map the loaded model to the specified single GPU.
+     loc = "cuda:{}".format(args.gpu)
+     checkpoint = torch.load(args.resume, map_location='cpu')
+     start_epoch = checkpoint["epoch"]
+     sd = checkpoint["state_dict"]
+     if next(iter(sd.items()))[0].startswith('module'):
+         sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
+     model.load_state_dict(sd, strict=False)
+     print(
+         f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
+     )
+
+     # Compute ensembled class embeddings.
+     print('Building zero-shot classifier')
+
+     model.eval()
+
+     with open(args.label_file, "r", encoding="utf8") as f:
+         classnames = [line.strip() for line in f.readlines()]
+
+     template_dict = {
+         "fgvc-aircraft-2013b-variants102": aircraft_templates,
+         "food-101": food_templates,
+         "oxford-flower-102": flower_templates,
+         "eurosat_clip": eurosat_templates,
+         "resisc45_clip": eurosat_templates,
+         "country211": country211_templates,
+         "openai": openai_templates,
+     }
+     if args.dataset in template_dict.keys():
+         templates = template_dict[args.dataset]
+     else:
+         templates = template_dict['openai']
+
+     # Make inference and evaluation.
+     print('Using classifier')
+     classifier = zero_shot_classifier(model, classnames, templates, args)
+     results = {}
+     top1, logits = run(model, classifier, data[args.dataset].dataloader, args)
+
+     results["zeroshot-top1"] = top1
+
+     print('Result:')
+     print(", ".join(["{}: {}".format(k, v) for k, v in results.items()]))
+     print('Finished.')
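(For reference, and inferred from the code above rather than stated in this commit: `--label-file` is read as one class name per line, in the order of the dataset's class indices. For the ImageNet evaluation, the script below points it at `label_cn.txt`, whose lines would be Chinese class names such as 金鱼 or 大白鲨 — illustrative examples, not quoted from this commit. Each line becomes one prompt-ensembled classifier column.)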
examples/pokemon.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ numpy
+ tqdm
+ six
+ timm
+ lmdb==1.3.0
+ torch>=1.7.1
+ torchvision
+ webdataset
+ pandas
+ transformers
scripts/zeroshot_eval.sh ADDED
@@ -0,0 +1,34 @@
+ #!/bin/bash
+
+ # Usage example:
+ # bash scripts/zeroshot_eval.sh 0 \
+ #     ${path_to_dataset} ${dataset_name} \
+ #     ViT-B-16 RoBERTa-wwm-ext-base-chinese \
+ #     ${ckpt_path}
+
+ # Only single-GPU inference is supported.
+ export CUDA_VISIBLE_DEVICES=${1}
+ export PYTHONPATH=${PYTHONPATH}:`pwd`/QA-CLIP-main
+
+ path=${2}
+ dataset=${3}
+ datapath=${path}
+ savedir=`pwd`/save_predictions
+ vision_model=${4} # e.g. ViT-B-16
+ text_model=${5}
+ resume=${6}
+ label_file=`pwd`/label_cn.txt
+ index=${7:-}
+
+ mkdir -p ${savedir}
+
+ python -u eval/zeroshot_evaluation.py \
+     --datapath="${datapath}" \
+     --label-file=${label_file} \
+     --save-dir=${savedir} \
+     --dataset=${dataset} \
+     --index=${index} \
+     --img-batch-size=64 \
+     --resume=${resume} \
+     --vision-model=${vision_model} \
+     --text-model=${text_model}
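(A hypothetical invocation for ImageNet zero-shot classification on GPU 0 — the data and checkpoint paths are placeholders, not part of this commit:

    bash scripts/zeroshot_eval.sh 0 \
        ${DATAPATH}/ImageNet-1K/test imagenet \
        ViT-B-16 RoBERTa-wwm-ext-base-chinese \
        ${DATAPATH}/QA-CLIP-base.pt
)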