ZacLiu committed on
Commit 6506e5f
1 Parent(s): 2466802

Update README.md

Files changed (1)
  1. README.md +15 -197
README.md CHANGED
@@ -249,209 +249,27 @@ Based on AltCLIP, we have also developed the AltDiffusion model, visualized as f
  ![](https://raw.githubusercontent.com/920232796/test/master/image7.png)

  ## 模型推理 Inference
-
  ```python
- import torch
  from PIL import Image
- from flagai.auto_model.auto_loader import AutoLoader
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- ## 一行代码直接自动下载权重到'./checkpoints/clip-xlmr-large',并自动加载CLIP模型权重 One line automatically downloads the weights to './checkpoints/clip-xlmr-large' and loads the CLIP model
- ## modelhub地址 Modelhub: https://model.baai.ac.cn/models
- loader = AutoLoader(
-     task_name="txt_img_matching",
-     model_dir="./checkpoints",
-     model_name="AltCLIP-XLMR-L"
- )
- ## 获取加载好的模型 Get the loaded model
- model = loader.get_model()
- ## 获取tokenizer Get the tokenizer
- tokenizer = loader.get_tokenizer()
- ## 获取transform用来处理图像 Get the transform used to preprocess images
- transform = loader.get_transform()
-
- model.eval()
- model.to(device)
-
- ## 推理过程,图像与文本匹配 Inference: match the image against the candidate texts
- image = Image.open("./dog.jpeg")
- image = transform(image)
- image = torch.tensor(image["pixel_values"]).to(device)
- text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]
-
- text = torch.tensor(text).to(device)
-
- with torch.no_grad():
-     image_features = model.get_image_features(image)
-     text_features = model.get_text_features(text)
-     text_probs = (image_features @ text_features.T).softmax(dim=-1)
-
- print(text_probs.cpu().numpy()[0].tolist())
- ```
-
- ## CLIP微调 Finetuning
-
- 微调采用cifar10数据集,并使用FlagAI的Trainer快速开始训练过程。
-
- Fine-tuning uses the cifar10 dataset and FlagAI's Trainer to quickly start the training process.
-
- ```python
- # Copyright © 2022 BAAI. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License")
- import torch
- from flagai.auto_model.auto_loader import AutoLoader
- import os
- from flagai.trainer import Trainer
- from torchvision.datasets import CIFAR10
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- dataset_root = "./clip_benchmark_datasets"
- dataset_name = "cifar10"
-
- batch_size = 4
- classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
-
- auto_loader = AutoLoader(
-     task_name="txt_img_matching",
-     model_dir="./checkpoints/",
-     model_name="AltCLIP-XLMR-L"  # Load the checkpoints from Modelhub (model.baai.ac.cn/models)
- )
-
- model = auto_loader.get_model()
- model.to(device)
- model.eval()
- tokenizer = auto_loader.get_tokenizer()
- transform = auto_loader.get_transform()
-
- trainer = Trainer(env_type="pytorch",
-                   pytorch_device=device,
-                   experiment_name="clip_finetuning",
-                   batch_size=4,
-                   lr=1e-4,
-                   epochs=10,
-                   log_interval=10)
-
- dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
-                   transform=transform,
-                   download=True)
-
- def cifar10_collate_fn(batch):
-     # image shape is (batch, 3, 224, 224)
-     images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
-     # text_id shape is (batch, n); use the class name, not the integer label, in the prompt
-     input_ids = torch.tensor([tokenizer(f"a photo of a {classes[b[1]]}", padding=True, truncation=True, max_length=77)["input_ids"] for b in batch])
-
-     return {
-         "pixel_values": images,
-         "input_ids": input_ids
-     }
-
- if __name__ == "__main__":
-     trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
- ```
-
- ## 模型验证 Evaluation
-
- 我们提供了可以直接运行的验证脚本,在cifar10数据集上进行验证。
-
- We provide a validation script that can be run directly to evaluate on the cifar10 dataset.
-
- 期待的输出为 The expected output is: ```{'dataset': 'cifar10', 'metrics': {'acc1': 0.95402, 'acc5': 0.99616, 'mean_per_class_recall': 0.9541200000000002}}```
-
- ```python
- # Copyright © 2022 BAAI. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License")
- import torch
- from flagai.auto_model.auto_loader import AutoLoader
- from metrics import zeroshot_classification
- import json
- import os
- from torchvision.datasets import CIFAR10
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- maxlen = 256
-
- dataset_root = "./clip_benchmark_datasets"
- dataset_name = "cifar10"
-
- auto_loader = AutoLoader(
-     task_name="txt_img_matching",
-     model_dir="./checkpoints/",
-     model_name="AltCLIP-XLMR-L"
- )
-
- model = auto_loader.get_model()
- model.to(device)
- model.eval()
- tokenizer = auto_loader.get_tokenizer()
- transform = auto_loader.get_transform()
-
- dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
-                   transform=transform,
-                   download=True)
- batch_size = 128
- num_workers = 4
-
- # Prompt templates used for zero-shot classification on cifar10
- template = {"cifar10": [
-     "a photo of a {c}.",
-     "a blurry photo of a {c}.",
-     "a black and white photo of a {c}.",
-     "a low contrast photo of a {c}.",
-     "a high contrast photo of a {c}.",
-     "a bad photo of a {c}.",
-     "a good photo of a {c}.",
-     "a photo of a small {c}.",
-     "a photo of a big {c}.",
-     "a photo of the {c}.",
-     "a blurry photo of the {c}.",
-     "a black and white photo of the {c}.",
-     "a low contrast photo of the {c}.",
-     "a high contrast photo of the {c}.",
-     "a bad photo of the {c}.",
-     "a good photo of the {c}.",
-     "a photo of the small {c}.",
-     "a photo of the big {c}."
-     ],
- }
-
- def evaluate():
-     if dataset:
-         dataloader = torch.utils.data.DataLoader(
-             dataset,
-             batch_size=batch_size,
-             shuffle=False,
-             num_workers=num_workers,
-         )
-         classnames = dataset.classes if hasattr(dataset, "classes") else None
-
-         zeroshot_templates = template["cifar10"]
-         metrics = zeroshot_classification.evaluate(
-             model,
-             dataloader,
-             tokenizer,
-             classnames,
-             zeroshot_templates,
-             device=device,
-             amp=True,
-         )
-
-         dump = {
-             "dataset": dataset_name,
-             "metrics": metrics
-         }
-
-         print(dump)
-         with open("./result.txt", "w") as f:
-             json.dump(dump, f)
-         return metrics
-
- if __name__ == "__main__":
-     evaluate()
-
  ```

  ![](https://raw.githubusercontent.com/920232796/test/master/image7.png)

  ## 模型推理 Inference
+ Please download the code from [FlagAI AltCLIP](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP).
  ```python
  from PIL import Image
+ import requests
+
+ # requires transformers >= 4.21.0
+ from modeling_altclip import AltCLIP
+ from processing_altclip import AltCLIPProcessor
+
+ # the repo is currently private, so add `use_auth_token=True` to `from_pretrained` if needed
+ model = AltCLIP.from_pretrained("BAAI/AltCLIP")
+ processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
+
+ outputs = model(**inputs)
+ logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+ probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
  ```

+
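
For reference, here is a minimal way to read the prediction out of the updated snippet. This is an illustrative sketch rather than part of the commit: it simply continues from the `probs` tensor computed above, and the `captions` list is an assumption that mirrors the order of the `text` argument passed to the processor.

```python
# Illustrative follow-up to the new inference snippet: map `probs` back to the captions.
captions = ["a photo of a cat", "a photo of a dog"]  # same order as passed to the processor
best = probs.argmax(dim=-1).item()                   # index of the highest-scoring caption
print(f"{captions[best]} ({probs[0, best].item():.4f})")
```

This just prints the highest-scoring caption together with its probability.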