---
license: creativeml-openrail-m
---

# AltCLIP-m18

| 名称 Name | 任务 Task | 语言 Language(s) | 模型 Model | Github |
|:------------------:|:----------:|:-------------------:|:--------:|:------:|
| AltCLIP-m18 | Text-Image | Multilingual | CLIP | [FlagAI](https://github.com/FlagAI-Open/FlagAI) |

## 简介 Brief Introduction

继双语模型AltCLIP与9语模型AltCLIP-m9之后,我们训练了18语CLIP模型,命名为AltCLIP-m18。它支持英语、中文、日语、泰语、韩语、印地语、乌克兰语、阿拉伯语、土耳其语、越南语、波兰语、荷兰语、葡萄牙语、意大利语、西班牙语、德语、法语和俄语。

AltCLIP-m18模型可以为AltDiffusion-m18模型提供支持,关于AltDiffusion模型的具体信息可查看[此教程](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion/README.md)。

模型代码已经在 [FlagAI](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP-m18) 上开源,权重位于我们搭建的 [modelhub](https://model.baai.ac.cn/model-detail/100095) 上。我们还提供了微调、推理、验证的脚本,欢迎试用。

Following the bilingual model AltCLIP and the nine-language model AltCLIP-m9, we have trained an eighteen-language CLIP model named AltCLIP-m18. It supports English, Chinese, Japanese, Thai, Korean, Hindi, Ukrainian, Arabic, Turkish, Vietnamese, Polish, Dutch, Portuguese, Italian, Spanish, German, French, and Russian.

The AltCLIP-m18 model provides support for the AltDiffusion-m18 model. Specific information on the AltDiffusion model can be found in [this tutorial](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion/README.md).

The model code has been open-sourced on [FlagAI](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP-m18), and the weights are available on our [modelhub](https://model.baai.ac.cn/model-detail/100095). We also provide scripts for fine-tuning, inference, and validation, so feel free to try them out.

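As a quick start, the checkpoint, tokenizer, and image preprocessing can all be obtained through FlagAI's `AutoLoader`, exactly as in the full scripts below. This is a minimal sketch; it assumes FlagAI is installed (e.g. `pip install flagai`) and that the checkpoint is downloaded into `./checkpoints` on first use.

```python
# Minimal loading sketch (assumes FlagAI is installed, e.g. `pip install flagai`).
# AutoLoader fetches the AltCLIP-XLMR-L-m18 checkpoint into model_dir on first use.
from flagai.auto_model.auto_loader import AutoLoader

loader = AutoLoader(
    task_name="txt_img_matching",
    model_name="AltCLIP-XLMR-L-m18",
    model_dir="./checkpoints",
)
model = loader.get_model()          # multilingual CLIP dual encoder
tokenizer = loader.get_tokenizer()  # XLM-R based multilingual tokenizer
transform = loader.get_transform()  # CLIP image preprocessing
```
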
## 引用 Citation

关于AltCLIP,我们已经推出了相关报告,有更多细节可以查阅,如对您的工作有帮助,欢迎引用。

We have released a report on AltCLIP with more details. If you find this work helpful, please consider citing it:

```
@article{https://doi.org/10.48550/arxiv.2211.06679,
  doi = {10.48550/ARXIV.2211.06679},
  url = {https://arxiv.org/abs/2211.06679},
  author = {Chen, Zhongzhi and Liu, Guang and Zhang, Bo-Wen and Ye, Fulong and Yang, Qinghong and Wu, Ledell},
  keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences},
  title = {AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
```

## AltCLIP-m18评测 AltCLIP-m18 Evaluation

部分数据集评测结果展示:

Evaluation results on a subset of the benchmark datasets:

| 模型 Model | birdsnap | caltech101 | cars | cifar10 | cifar100 | country211 | dtd | eurosat | fer2013 |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| AltCLIP-m18 | 41.57 | 88.25 | 92.75 | 97.44 | 84.83 | 30.52 | 68.62 | 67.46 | 54.4 |

Zero-shot evaluation script for the CIFAR10 dataset:

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
import zeroshot_classification
import json
import os
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
maxlen = 256

dataset_root = "./clip_benchmark_datasets/"
dataset_name = "cifar10"

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L-m18"  # Load the checkpoints from Modelhub (model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)
batch_size = 128
num_workers = 4

template = {"cifar10": [
    "a photo of a {c}.",
    "a blurry photo of a {c}.",
    "a black and white photo of a {c}.",
    "a low contrast photo of a {c}.",
    "a high contrast photo of a {c}.",
    "a bad photo of a {c}.",
    "a good photo of a {c}.",
    "a photo of a small {c}.",
    "a photo of a big {c}.",
    "a photo of the {c}.",
    "a blurry photo of the {c}.",
    "a black and white photo of the {c}.",
    "a low contrast photo of the {c}.",
    "a high contrast photo of the {c}.",
    "a bad photo of the {c}.",
    "a good photo of the {c}.",
    "a photo of the small {c}.",
    "a photo of the big {c}."
],
}

def evaluate():
    if dataset:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        zeroshot_templates = template["cifar10"]
        classnames = dataset.classes if hasattr(dataset, "classes") else None

        metrics = zeroshot_classification.evaluate(
            model,
            dataloader,
            tokenizer,
            classnames,
            zeroshot_templates,
            device=device,
            amp=True,
        )

        dump = {
            "dataset": dataset_name,
            "metrics": metrics
        }

        print(dump)
        with open("./result.txt", "w") as f:
            json.dump(dump, f)
        return metrics

if __name__ == "__main__":
    evaluate()
```

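The script prints the metrics and also writes them to `result.txt`. A small follow-up sketch for reading the saved results back (the exact keys inside `metrics` depend on the `zeroshot_classification` module):

```python
import json

# Read back the results written by the evaluation script above.
with open("./result.txt") as f:
    result = json.load(f)

print(result["dataset"])   # "cifar10"
print(result["metrics"])   # metrics dict returned by zeroshot_classification.evaluate
```
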
## 推理脚本 Inference

```python
import torch
from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(
    task_name="txt_img_matching",
    model_name="AltCLIP-XLMR-L-m18",  # Load the checkpoints from Modelhub (model.baai.ac.cn/models)
    model_dir="./checkpoints"
)

model = loader.get_model()
tokenizer = loader.get_tokenizer()
transform = loader.get_transform()

model.eval()
model.to(device)

def inference():
    image = Image.open("./dog.jpeg")
    image = transform(image)
    image = torch.tensor(image["pixel_values"]).to(device)
    tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
                              padding=True,
                              truncation=True,
                              max_length=77,
                              return_tensors='pt')

    text = tokenizer_out["input_ids"].to(device)
    attention_mask = tokenizer_out["attention_mask"].to(device)
    with torch.no_grad():
        image_features = model.get_image_features(image)
        text_features = model.get_text_features(text, attention_mask=attention_mask)
        text_probs = (image_features @ text_features.T).softmax(dim=-1)

    print(text_probs.cpu().numpy()[0].tolist())

if __name__ == "__main__":
    inference()
```
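Because AltCLIP-m18 aligns all 18 supported languages with the same image encoder, the candidate captions do not have to be English. The sketch below reuses `model`, `tokenizer`, `transform`, and `device` from the script above; the prompt wording is only illustrative.

```python
import torch
from PIL import Image

# Candidate captions in several supported languages ("a dog" in Chinese,
# French, and German, plus an English distractor).
prompts = ["一只狗", "un chien", "ein Hund", "a cat"]

image = Image.open("./dog.jpeg")
pixel_values = torch.tensor(transform(image)["pixel_values"]).to(device)

tokenizer_out = tokenizer(prompts,
                          padding=True,
                          truncation=True,
                          max_length=77,
                          return_tensors="pt")

with torch.no_grad():
    image_features = model.get_image_features(pixel_values)
    text_features = model.get_text_features(tokenizer_out["input_ids"].to(device),
                                            attention_mask=tokenizer_out["attention_mask"].to(device))
    probs = (image_features @ text_features.T).softmax(dim=-1)

# One probability per prompt for the single input image.
print(list(zip(prompts, probs.cpu().numpy()[0].tolist())))
```
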

## 微调 Fine-tuning

Fine-tuning on the CIFAR10 dataset:

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
import os
from flagai.trainer import Trainer
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

batch_size = 4
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints",
    model_name="AltCLIP-XLMR-L-m18"  # Load the checkpoints from Modelhub (model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

trainer = Trainer(env_type="pytorch",
                  pytorch_device=device,
                  experiment_name="clip_finetuning",
                  batch_size=batch_size,
                  lr=1e-4,
                  epochs=10,
                  log_interval=10)

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)

def cifar10_collate_fn(batch):
    # image shape is (batch, 3, 224, 224)
    images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
    # b[1] is the CIFAR10 label index, so map it to the class name and
    # tokenize the whole batch at once so every caption is padded to the
    # same length.
    texts = [f"a photo of a {classes[b[1]]}" for b in batch]
    tokenizer_out = tokenizer(texts,
                              padding=True,
                              truncation=True,
                              max_length=77,
                              return_tensors="pt")

    return {
        "pixel_values": images,
        "input_ids": tokenizer_out["input_ids"],
        "attention_mask": tokenizer_out["attention_mask"],
    }

if __name__ == "__main__":
    trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
```
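A quick sanity check of the collate function before launching training (a sketch that assumes the `dataset`, `tokenizer`, and `batch_size` defined above):

```python
# Build one batch by hand and inspect the shapes the trainer will receive.
sample_batch = cifar10_collate_fn([dataset[i] for i in range(batch_size)])

print(sample_batch["pixel_values"].shape)    # expected torch.Size([4, 3, 224, 224])
print(sample_batch["input_ids"].shape)       # (4, padded caption length)
print(sample_batch["attention_mask"].shape)  # same shape as input_ids
```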