ngxquang commited on
Commit
6294617
1 Parent(s): 1775056

feat: upload BeiT3 API for deployment

Browse files
Files changed (40) hide show
  1. .env +10 -0
  2. Dockerfile +39 -0
  3. data/config/keyframes_groups_L01_to_L20.json +3 -0
  4. data/faiss-index/index_beit3_L01_to_L20.faiss +3 -0
  5. download_models.sh +11 -0
  6. requirements.dev.txt +21 -0
  7. requirements.txt +18 -0
  8. src/__init__.py +0 -0
  9. src/__pycache__/config.cpython-311.pyc +0 -0
  10. src/__pycache__/config.cpython-38.pyc +0 -0
  11. src/__pycache__/main.cpython-311.pyc +0 -0
  12. src/__pycache__/main.cpython-38.pyc +0 -0
  13. src/config.py +28 -0
  14. src/itr/__init__.py +0 -0
  15. src/itr/__pycache__/__init__.cpython-311.pyc +0 -0
  16. src/itr/__pycache__/__init__.cpython-38.pyc +0 -0
  17. src/itr/__pycache__/beit3_model.cpython-311.pyc +0 -0
  18. src/itr/__pycache__/beit3_model.cpython-38.pyc +0 -0
  19. src/itr/__pycache__/dtb_cursor.cpython-311.pyc +0 -0
  20. src/itr/__pycache__/dtb_cursor.cpython-38.pyc +0 -0
  21. src/itr/__pycache__/modeling_finetune.cpython-311.pyc +0 -0
  22. src/itr/__pycache__/modeling_finetune.cpython-38.pyc +0 -0
  23. src/itr/__pycache__/modeling_utils.cpython-311.pyc +0 -0
  24. src/itr/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  25. src/itr/__pycache__/router.cpython-311.pyc +0 -0
  26. src/itr/__pycache__/router.cpython-38.pyc +0 -0
  27. src/itr/__pycache__/utils.cpython-311.pyc +0 -0
  28. src/itr/__pycache__/utils.cpython-38.pyc +0 -0
  29. src/itr/__pycache__/vlm_model.cpython-311.pyc +0 -0
  30. src/itr/__pycache__/vlm_model.cpython-38.pyc +0 -0
  31. src/itr/beit3/README.md +28 -0
  32. src/itr/beit3_model.py +109 -0
  33. src/itr/beit3_model/README.md +6 -0
  34. src/itr/dtb_cursor.py +30 -0
  35. src/itr/modeling_finetune.py +388 -0
  36. src/itr/modeling_utils.py +108 -0
  37. src/itr/router.py +49 -0
  38. src/itr/utils.py +891 -0
  39. src/itr/vlm_model.py +31 -0
  40. src/main.py +67 -0
.env ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # PROJECT INFORMATION
2
+ HOST=0.0.0.0
3
+ PORT=7860
4
+ CORS_HEADERS=["*"]
5
+ CORS_ORIGINS=["*"]
6
+
7
+ DEVICE="cpu" # ["cuda", "cpu"]
8
+
9
+ INDEX_FILE_PATH="data/faiss-index/index_beit3_L01_to_L20.faiss"
10
+ KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L20.json"
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8-slim
2
+
3
+ RUN apt-get update && \
4
+ apt-get install git bash wget iputils-ping -y && \
5
+ apt clean && \
6
+ rm -rf /var/cache/apt/*
7
+
8
+ WORKDIR /code
9
+
10
+ COPY requirements.txt /code/requirements.txt
11
+
12
+ # PYTHONDONTWRITEBYTECODE=1: Disables the creation of .pyc files (compiled bytecode)
13
+ # PYTHONUNBUFFERED=1: Disables buffering of the standard output stream
14
+ # PYTHONIOENCODING: specifies the encoding to be used for the standard input, output, and error streams
15
+ ENV PYTHONDONTWRITEBYTECODE=1 \
16
+ PYTHONUNBUFFERED=1 \
17
+ PYTHONIOENCODING=utf-8
18
+
19
+ RUN pip install -U pip && \
20
+ python -m pip install -r /code/requirements.txt
21
+
22
+ RUN useradd -m -u 1000 user
23
+
24
+ USER user
25
+
26
+ ENV HOME=/home/user \
27
+ PATH=/home/user/.local/bin:$PATH
28
+
29
+ WORKDIR $HOME/app
30
+
31
+ COPY --chown=user . $HOME/app
32
+
33
+ # Download index
34
+ # RUN mkdir ./data/faiss-index/ && \
35
+ # gsutil cp "gs://thangtd1/faiss-index/index_beit3_L01_to_L20.faiss" ./data/faiss-index/
36
+
37
+ RUN bash $HOME/app/download_models.sh
38
+
39
+ CMD python ./src/main.py
data/config/keyframes_groups_L01_to_L20.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e66e7bd8837d2b9dc2ce3324e19b3346335dc2437ecfd4279138485bc51422f
3
+ size 26271112
data/faiss-index/index_beit3_L01_to_L20.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b5a31a0b7815943a055dfd5d39bf2be3ba33447586354cb2fa6505894b3c832
3
+ size 620998701
download_models.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Test network connection before download
4
+ ping -c 4 8.8.8.8
5
+ if [ $? -gt 0 ]
6
+ then
7
+ exit 1
8
+ fi
9
+ # Download processor and beit-3 model from provided urls
10
+ wget "https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D" -O ./src/itr/beit3_model/beit3.spm
11
+ wget "https://conversationhub.blob.core.windows.net/beit-share-public/beit3/f30k_retrieval/beit3_base_patch16_384_f30k_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D" -O ./src/itr/beit3_model/beit3_base_patch16_384_f30k_retrieval.pth
requirements.dev.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.103.1
2
+ uvicorn==0.23.2
3
+ pydantic-settings==2.0.3
4
+
5
+
6
+ # Models
7
+ torch==2.0.0
8
+ torchvision==0.15.1
9
+ torchscale==0.2.0
10
+ ftfy==6.1.1
11
+ regex
12
+ tqdm==4.66.1
13
+ transformers==4.33.1
14
+ timm==0.4.12
15
+ sentencepiece==0.1.99
16
+
17
+ # Vector Database
18
+ faiss-cpu
19
+
20
+ # Project settings
21
+ pre-commit
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.103.1
2
+ uvicorn==0.23.2
3
+ pydantic-settings==2.0.3
4
+
5
+
6
+ # Models
7
+ torch==2.0.0
8
+ torchvision==0.15.1
9
+ torchscale==0.2.0
10
+ ftfy==6.1.1
11
+ regex
12
+ tqdm==4.66.1
13
+ transformers==4.33.1
14
+ timm==0.4.12
15
+ sentencepiece==0.1.99
16
+
17
+ # Vector Database
18
+ faiss-cpu
src/__init__.py ADDED
File without changes
src/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.3 kB). View file
 
src/__pycache__/config.cpython-38.pyc ADDED
Binary file (929 Bytes). View file
 
src/__pycache__/main.cpython-311.pyc ADDED
Binary file (3.49 kB). View file
 
src/__pycache__/main.cpython-38.pyc ADDED
Binary file (2.13 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from pydantic_settings import BaseSettings
4
+
5
+ FILE = Path(__file__)
6
+ ROOT = FILE.parent.parent
7
+
8
+
9
+ class Settings(BaseSettings):
10
+ # API SETTINGS
11
+ HOST: str
12
+ PORT: int
13
+ CORS_ORIGINS: list
14
+ CORS_HEADERS: list
15
+
16
+ # MODEL SETTINGS
17
+ MODEL_NAME: str = "ViT-B/32"
18
+ DEVICE: str = "cpu"
19
+
20
+ # FAISS DATABASE SETTINGS
21
+ INDEX_FILE_PATH: str
22
+ KEYFRAMES_GROUPS_JSON_PATH: str
23
+
24
+ class Config:
25
+ env_file = ROOT / ".env"
26
+
27
+
28
+ settings = Settings()
src/itr/__init__.py ADDED
File without changes
src/itr/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (214 Bytes). View file
 
src/itr/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (196 Bytes). View file
 
src/itr/__pycache__/beit3_model.cpython-311.pyc ADDED
Binary file (6.17 kB). View file
 
src/itr/__pycache__/beit3_model.cpython-38.pyc ADDED
Binary file (3.4 kB). View file
 
src/itr/__pycache__/dtb_cursor.cpython-311.pyc ADDED
Binary file (2.74 kB). View file
 
src/itr/__pycache__/dtb_cursor.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
src/itr/__pycache__/modeling_finetune.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
src/itr/__pycache__/modeling_finetune.cpython-38.pyc ADDED
Binary file (10.4 kB). View file
 
src/itr/__pycache__/modeling_utils.cpython-311.pyc ADDED
Binary file (4.86 kB). View file
 
src/itr/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (3.09 kB). View file
 
src/itr/__pycache__/router.cpython-311.pyc ADDED
Binary file (2.39 kB). View file
 
src/itr/__pycache__/router.cpython-38.pyc ADDED
Binary file (1.51 kB). View file
 
src/itr/__pycache__/utils.cpython-311.pyc ADDED
Binary file (48.2 kB). View file
 
src/itr/__pycache__/utils.cpython-38.pyc ADDED
Binary file (24.8 kB). View file
 
src/itr/__pycache__/vlm_model.cpython-311.pyc ADDED
Binary file (2.83 kB). View file
 
src/itr/__pycache__/vlm_model.cpython-38.pyc ADDED
Binary file (1.56 kB). View file
 
src/itr/beit3/README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Using BEiT-3 to get text-vision embedding
2
+
3
+ ## For text embedding
4
+ 1. Create file ```test_model.py``` inside folder ```itr```.
5
+ 2. Using the code follow:
6
+ ```
7
+ from beit3_model import Beit3Model
8
+
9
+ if __name__ == '__main__':
10
+ vlm = Beit3Model(device='cpu')
11
+
12
+ print(vlm.get_embedding('A man who loves a girl.').shape)
13
+ ```
14
+
15
+ ## For image embedding
16
+ 1. Create file ```test_model.py``` inside folder ```itr```.
17
+ 2. Using the code follow:
18
+ ```
19
+ from beit3_model import Beit3Model
20
+ from torchvision.datasets.folder import default_loader
21
+
22
+ if __name__ == '__main__':
23
+ loader = default_loader
24
+ image = loader('./path/to/your/image.jpg')
25
+
26
+ vlm = Beit3Model(device='cpu')
27
+ print(vlm.get_embedding(image).shape)
28
+ ```
src/itr/beit3_model.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from functools import lru_cache
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ from . import modeling_finetune, utils
8
+ from PIL import Image
9
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10
+ from timm.models import create_model
11
+ from torchvision import transforms
12
+ from transformers import XLMRobertaTokenizer
13
+
14
+ # Get current workdir of this file
15
+ CWD = Path(__file__).parent
16
+ print(CWD)
17
+
18
+
19
+ class Preprocess:
20
+ def __init__(self, tokenizer):
21
+ self.max_len = 64
22
+ self.input_size = 384
23
+
24
+ self.tokenizer = tokenizer
25
+ self.transform = transforms.Compose(
26
+ [
27
+ transforms.Resize((self.input_size, self.input_size), interpolation=3),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(
30
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
31
+ ),
32
+ ]
33
+ )
34
+
35
+ self.bos_token_id = tokenizer.bos_token_id
36
+ self.eos_token_id = tokenizer.eos_token_id
37
+ self.pad_token_id = tokenizer.pad_token_id
38
+
39
+ def preprocess(self, input: Union[str, Image.Image]):
40
+ if isinstance(input, str):
41
+ tokens = self.tokenizer.tokenize(input)
42
+ tokens = self.tokenizer.convert_tokens_to_ids(tokens)
43
+
44
+ tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id]
45
+ num_tokens = len(tokens)
46
+ padding_mask = [0] * num_tokens + [1] * (self.max_len - num_tokens)
47
+
48
+ return (
49
+ torch.LongTensor(
50
+ tokens + [self.pad_token_id] * (self.max_len - num_tokens)
51
+ ).unsqueeze(0),
52
+ torch.Tensor(padding_mask).unsqueeze(0),
53
+ num_tokens,
54
+ )
55
+ elif isinstance(input, Image.Image):
56
+ return self.transform(input).unsqueeze(0)
57
+ else:
58
+ raise Exception("Invalid input type")
59
+
60
+
61
+ class Beit3Model:
62
+ def __init__(
63
+ self,
64
+ model_name: str = "beit3_base_patch16_384_retrieval",
65
+ model_path: str = os.path.join(
66
+ CWD,
67
+ "beit3_model/beit3_base_patch16_384_f30k_retrieval.pth",
68
+ ),
69
+ device: str = "cuda",
70
+ ):
71
+ self._load_model(model_name, model_path, device)
72
+ self.device = device
73
+
74
+ # @lru_cache(maxsize=1)
75
+ def _load_model(self, model_name, model_path, device: str = "cpu"):
76
+ self.model = create_model(
77
+ model_name,
78
+ pretrained=False,
79
+ drop_path_rate=0.1,
80
+ vocab_size=64010,
81
+ checkpoint_activations=False,
82
+ )
83
+
84
+ if model_name:
85
+ utils.load_model_and_may_interpolate(
86
+ model_path, self.model, "model|module", ""
87
+ )
88
+
89
+ self.preprocessor = Preprocess(
90
+ XLMRobertaTokenizer(os.path.join(CWD, "beit3_model/beit3.spm"))
91
+ )
92
+ self.model.to(device)
93
+
94
+ def get_embedding(self, input: Union[str, Image.Image]):
95
+ if isinstance(input, str):
96
+ token_ids, padding_mask, _ = self.preprocessor.preprocess(input)
97
+
98
+ _, vector = self.model(
99
+ text_description=token_ids, padding_mask=padding_mask, only_infer=True
100
+ )
101
+ vector = vector.cpu().detach().numpy().astype("float32")
102
+ return vector
103
+ elif isinstance(input, Image.Image):
104
+ image_input = self.preprocessor.preprocess(input)
105
+ image_input = image_input.to(self.device)
106
+ vector, _ = self.model(image=image_input, only_infer=True)
107
+ return vector
108
+ else:
109
+ raise Exception("Invalid input type")
src/itr/beit3_model/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # BEiT-3 Weight and Sentencepiece models
2
+
3
+ 1. Please download [beit3_base_patch16_384_retrieval.pth](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/f30k_retrieval/beit3_base_patch16_384_f30k_retrieval.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) and [beit3.spm](https://conversationhub.blob.core.windows.net/beit-share-public/beit3/sentencepiece/beit3.spm?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D) model.
4
+
5
+ 2. Put those 2 model inside this folder
6
+
src/itr/dtb_cursor.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import faiss
3
+ import os
4
+
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+
8
+
9
+ class DatabaseCursor:
10
+ def __init__(self, index_file_path: str, keyframes_groups_json_path: str):
11
+ self._load_index(index_file_path)
12
+ self._load_keyframes_groups_info(keyframes_groups_json_path)
13
+
14
+ @lru_cache(maxsize=1)
15
+ def _load_index(self, index_file_path):
16
+ self.index = faiss.read_index(index_file_path)
17
+
18
+ @lru_cache(maxsize=1)
19
+ def _load_keyframes_groups_info(self, keyframes_groups_json_path: str):
20
+ with open(keyframes_groups_json_path) as file:
21
+ self.keyframes_group_info = json.loads(file.read())
22
+
23
+ def kNN_search(self, query_vector: str, topk: int = 10):
24
+ results = []
25
+ distances, ids = self.index.search(query_vector, topk)
26
+ for i in range(len(ids[0])):
27
+ frame_detail = self.keyframes_group_info[ids[0][i]]
28
+ frame_detail["distance"] = str(distances[0][i])
29
+ results.append(frame_detail)
30
+ return results
src/itr/modeling_finetune.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from . import utils
13
+ from .modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
14
+ from timm.models.registry import register_model
15
+
16
+
17
+ class TwoLayerMLP(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features,
21
+ hidden_features,
22
+ out_features,
23
+ norm_layer,
24
+ norm_input=True,
25
+ ):
26
+ super().__init__()
27
+ self.norm1 = norm_layer(in_features) if norm_input else nn.Identity()
28
+ self.dense1 = nn.Linear(in_features, hidden_features)
29
+ self.norm2 = norm_layer(hidden_features)
30
+ self.act = nn.GELU()
31
+ self.dense2 = nn.Linear(hidden_features, out_features)
32
+
33
+ def forward(self, x):
34
+ x = self.norm1(x)
35
+ x = self.dense1(x)
36
+ x = self.norm2(x)
37
+ x = self.act(x)
38
+ return self.dense2(x)
39
+
40
+
41
+ class Pooler(nn.Module):
42
+ def __init__(self, input_features, output_features, norm_layer):
43
+ super().__init__()
44
+ self.norm = norm_layer(input_features)
45
+ self.dense = nn.Linear(input_features, output_features)
46
+ self.activation = nn.Tanh()
47
+
48
+ def forward(self, x):
49
+ cls_rep = x[:, 0, :]
50
+ cls_rep = self.norm(cls_rep)
51
+ pooled_output = self.dense(cls_rep)
52
+ pooled_output = self.activation(pooled_output)
53
+ return pooled_output
54
+
55
+
56
+ class BEiT3ForVisualReasoning(BEiT3Wrapper):
57
+ def __init__(self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs):
58
+ super().__init__(args=args)
59
+ embed_dim = args.encoder_embed_dim
60
+ self.head = TwoLayerMLP(
61
+ in_features=embed_dim * 4,
62
+ hidden_features=embed_dim * 2,
63
+ out_features=num_classes,
64
+ norm_layer=norm_layer,
65
+ )
66
+ init_scale = 0.001
67
+ self.head.apply(self._init_weights)
68
+ if isinstance(self.head.dense1, nn.Linear):
69
+ self.head.dense1.weight.data.mul_(init_scale)
70
+ self.head.dense1.bias.data.mul_(init_scale)
71
+
72
+ if isinstance(self.head.dense2, nn.Linear):
73
+ self.head.dense2.weight.data.mul_(init_scale)
74
+ self.head.dense2.bias.data.mul_(init_scale)
75
+
76
+ def forward(self, image_a, image_b, text_description, padding_mask, **kwargs):
77
+ bsz, _ = text_description.size()
78
+
79
+ vision_input = torch.cat((image_a, image_b), dim=0)
80
+ language_input = torch.cat((text_description, text_description), dim=0)
81
+ padding_mask = torch.cat((padding_mask, padding_mask), dim=0)
82
+
83
+ outputs = self.beit3(
84
+ textual_tokens=language_input,
85
+ visual_tokens=vision_input,
86
+ text_padding_position=padding_mask,
87
+ )
88
+ x = outputs["encoder_out"]
89
+ multiway_split_position = outputs["multiway_split_position"]
90
+
91
+ vision_cls = x[:, 0, :]
92
+ language_cls = x[:, multiway_split_position, :]
93
+ cls_rep = torch.cat((vision_cls, language_cls), dim=-1)
94
+ a, b = torch.split(cls_rep, split_size_or_sections=[bsz, bsz], dim=0)
95
+ cls_rep = torch.cat((a, b), dim=-1)
96
+ return self.head(cls_rep)
97
+
98
+
99
+ class BEiT3ForImageClassification(BEiT3Wrapper):
100
+ def __init__(self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs):
101
+ super().__init__(args=args)
102
+ embed_dim = args.encoder_embed_dim
103
+ self.fc_norm = norm_layer(embed_dim)
104
+ self.head = (
105
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
106
+ )
107
+
108
+ self.fc_norm.apply(self._init_weights)
109
+ self.head.apply(self._init_weights)
110
+ init_scale = 0.001
111
+ if isinstance(self.head, nn.Linear):
112
+ self.head.weight.data.mul_(init_scale)
113
+ self.head.bias.data.mul_(init_scale)
114
+
115
+ def forward(self, image, **kwargs):
116
+ x = self.beit3(textual_tokens=None, visual_tokens=image)["encoder_out"]
117
+ t = x[:, 1:, :]
118
+ cls_x = self.fc_norm(t.mean(1))
119
+ return self.head(cls_x)
120
+
121
+
122
+ class BEiT3ForCaptioning(BEiT3Wrapper):
123
+ def __init__(self, args, **kwargs):
124
+ super().__init__(args=args)
125
+ embed_dim = args.encoder_embed_dim
126
+ self.mlm_head = nn.Linear(embed_dim, args.vocab_size)
127
+ self.mlm_head.apply(self._init_weights)
128
+
129
+ def forward(
130
+ self,
131
+ image,
132
+ text_ids,
133
+ padding_mask,
134
+ language_masked_pos,
135
+ text_len=None,
136
+ incremental_state=None,
137
+ **kwargs
138
+ ):
139
+ text_len = text_len if text_len is not None else text_ids.size(1)
140
+ image_len = self.beit3.vision_embed.num_position_embeddings()
141
+ max_len = text_len + image_len
142
+ uni_mask = torch.zeros(
143
+ (max_len, max_len), dtype=torch.long, device=text_ids.device
144
+ )
145
+ i_start, i_end = 0, image_len
146
+ t_start, t_end = image_len, max_len
147
+ # triangle mask for caption to caption
148
+ uni_mask[t_start:t_end, t_start:t_end] = torch.tril(
149
+ torch.ones(text_len, text_len, dtype=torch.long, device=text_ids.device)
150
+ )
151
+ # full attention for caption to image
152
+ uni_mask[t_start:t_end, i_start:i_end] = 1
153
+ # full attention for image to image
154
+ uni_mask[i_start:i_end, i_start:i_end] = 1
155
+ uni_mask = 1 - uni_mask
156
+
157
+ if incremental_state is not None:
158
+ for idx in range(self.get_num_layers()):
159
+ if idx not in incremental_state:
160
+ incremental_state[idx] = {}
161
+
162
+ # for incremental decoding
163
+ positions = None
164
+ if image is None:
165
+ uni_mask = uni_mask[-2:]
166
+ padding_mask = None
167
+ # start position (2 (fairseq starts at 2) + cur_position) is equal to text_len
168
+ positions = (
169
+ torch.arange(
170
+ text_len, text_ids.size(1) + text_len, device=text_ids.device
171
+ )
172
+ .long()
173
+ .unsqueeze(0)
174
+ )
175
+
176
+ outputs = self.beit3(
177
+ textual_tokens=text_ids,
178
+ visual_tokens=image,
179
+ text_padding_position=padding_mask,
180
+ attn_mask=uni_mask,
181
+ incremental_state=incremental_state,
182
+ positions=positions,
183
+ )
184
+ if image is not None:
185
+ text_feats = outputs["encoder_out"][:, image_len:]
186
+ else:
187
+ text_feats = outputs["encoder_out"]
188
+
189
+ if language_masked_pos is not None:
190
+ text_feats = text_feats[language_masked_pos.bool()]
191
+
192
+ return self.mlm_head(text_feats), incremental_state
193
+
194
+
195
+ class BEiT3ForVisualQuestionAnswering(BEiT3Wrapper):
196
+ def __init__(self, args, num_classes, norm_layer=nn.LayerNorm, **kwargs):
197
+ super().__init__(args=args)
198
+ embed_dim = args.encoder_embed_dim
199
+ self.pooler = Pooler(
200
+ input_features=embed_dim,
201
+ output_features=embed_dim,
202
+ norm_layer=norm_layer,
203
+ )
204
+ self.pooler.apply(self._init_weights)
205
+ self.head = nn.Sequential(
206
+ nn.Linear(embed_dim, embed_dim * 2),
207
+ norm_layer(embed_dim * 2),
208
+ nn.GELU(),
209
+ nn.Linear(embed_dim * 2, num_classes),
210
+ )
211
+ self.head.apply(self._init_weights)
212
+
213
+ def forward(self, image, question, padding_mask, **kwargs):
214
+ outputs = self.beit3(
215
+ textual_tokens=question,
216
+ visual_tokens=image,
217
+ text_padding_position=padding_mask,
218
+ )
219
+ x = outputs["encoder_out"]
220
+ cls_rep = self.pooler(x)
221
+ return self.head(cls_rep)
222
+
223
+
224
+ class BEiT3ForRetrieval(BEiT3Wrapper):
225
+ def __init__(self, args, **kwargs):
226
+ super().__init__(args=args)
227
+ embed_dim = args.encoder_embed_dim
228
+ self.language_head = nn.Linear(embed_dim, embed_dim, bias=False)
229
+ self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False)
230
+ self.language_head.apply(self._init_weights)
231
+ self.vision_head.apply(self._init_weights)
232
+ self.criterion = utils.ClipLoss(
233
+ rank=utils.get_rank(),
234
+ world_size=utils.get_world_size(),
235
+ )
236
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
237
+
238
+ def forward(
239
+ self,
240
+ image=None,
241
+ text_description=None,
242
+ padding_mask=None,
243
+ only_infer=False,
244
+ **kwargs
245
+ ):
246
+ if image is not None:
247
+ outputs = self.beit3(
248
+ textual_tokens=None,
249
+ visual_tokens=image,
250
+ text_padding_position=None,
251
+ )
252
+ x = outputs["encoder_out"]
253
+ vision_cls = self.vision_head(x[:, 0, :])
254
+ vision_cls = F.normalize(vision_cls, dim=-1)
255
+ else:
256
+ vision_cls = None
257
+
258
+ if text_description is not None:
259
+ outputs = self.beit3(
260
+ textual_tokens=text_description,
261
+ visual_tokens=None,
262
+ text_padding_position=padding_mask,
263
+ )
264
+ x = outputs["encoder_out"]
265
+ language_cls = self.language_head(x[:, 0, :])
266
+ language_cls = F.normalize(language_cls, dim=-1)
267
+ else:
268
+ language_cls = None
269
+
270
+ if only_infer:
271
+ return vision_cls, language_cls
272
+ else:
273
+ loss, logits_per_image, logits_per_text = self.criterion(
274
+ vision_cls, language_cls, self.logit_scale.exp()
275
+ )
276
+ return loss, vision_cls, language_cls
277
+
278
+
279
+ @register_model
280
+ def beit3_base_patch16_224_imageclassification(pretrained=False, **kwargs):
281
+ args = _get_base_config(**kwargs)
282
+ args.normalize_output = False
283
+ model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
284
+ return model
285
+
286
+
287
+ @register_model
288
+ def beit3_large_patch16_224_imageclassification(pretrained=False, **kwargs):
289
+ args = _get_large_config(**kwargs)
290
+ args.normalize_output = False
291
+ model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
292
+ return model
293
+
294
+
295
+ @register_model
296
+ def beit3_base_patch16_224_nlvr2(pretrained=False, **kwargs):
297
+ args = _get_base_config(**kwargs)
298
+ model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
299
+ return model
300
+
301
+
302
+ @register_model
303
+ def beit3_large_patch16_224_nlvr2(pretrained=False, **kwargs):
304
+ args = _get_large_config(**kwargs)
305
+ model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
306
+ return model
307
+
308
+
309
+ @register_model
310
+ def beit3_base_patch16_384_vqav2(pretrained=False, **kwargs):
311
+ args = _get_base_config(img_size=384, **kwargs)
312
+ args.normalize_output = False
313
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
314
+ return model
315
+
316
+
317
+ @register_model
318
+ def beit3_base_patch16_480_vqav2(pretrained=False, **kwargs):
319
+ args = _get_base_config(img_size=480, **kwargs)
320
+ args.normalize_output = False
321
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
322
+ return model
323
+
324
+
325
+ @register_model
326
+ def beit3_large_patch16_384_vqav2(pretrained=False, **kwargs):
327
+ args = _get_large_config(img_size=384, **kwargs)
328
+ args.normalize_output = False
329
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
330
+ return model
331
+
332
+
333
+ @register_model
334
+ def beit3_large_patch16_480_vqav2(pretrained=False, **kwargs):
335
+ args = _get_large_config(img_size=480, **kwargs)
336
+ args.normalize_output = False
337
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
338
+ return model
339
+
340
+
341
+ @register_model
342
+ def beit3_large_patch16_768_vqav2(pretrained=False, **kwargs):
343
+ args = _get_large_config(img_size=768, **kwargs)
344
+ args.normalize_output = False
345
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
346
+ return model
347
+
348
+
349
+ @register_model
350
+ def beit3_base_patch16_224_captioning(pretrained=False, **kwargs):
351
+ args = _get_base_config(**kwargs)
352
+ model = BEiT3ForCaptioning(args, **kwargs)
353
+ return model
354
+
355
+
356
+ @register_model
357
+ def beit3_base_patch16_480_captioning(pretrained=False, **kwargs):
358
+ args = _get_base_config(img_size=480, **kwargs)
359
+ model = BEiT3ForCaptioning(args, **kwargs)
360
+ return model
361
+
362
+
363
+ @register_model
364
+ def beit3_large_patch16_480_captioning(pretrained=False, **kwargs):
365
+ args = _get_large_config(img_size=480, **kwargs)
366
+ model = BEiT3ForCaptioning(args, **kwargs)
367
+ return model
368
+
369
+
370
+ @register_model
371
+ def beit3_base_patch16_224_retrieval(pretrained=False, **kwargs):
372
+ args = _get_base_config(**kwargs)
373
+ model = BEiT3ForRetrieval(args, **kwargs)
374
+ return model
375
+
376
+
377
+ @register_model
378
+ def beit3_base_patch16_384_retrieval(pretrained=False, **kwargs):
379
+ args = _get_base_config(img_size=384, **kwargs)
380
+ model = BEiT3ForRetrieval(args, **kwargs)
381
+ return model
382
+
383
+
384
+ @register_model
385
+ def beit3_large_patch16_384_retrieval(pretrained=False, **kwargs):
386
+ args = _get_large_config(img_size=384, **kwargs)
387
+ model = BEiT3ForRetrieval(args, **kwargs)
388
+ return model
src/itr/modeling_utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
13
+ from torchscale.architecture.config import EncoderConfig
14
+ from torchscale.model.BEiT3 import BEiT3
15
+
16
+
17
+ def trunc_normal_(tensor, mean=0.0, std=1.0):
18
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19
+
20
+
21
+ def _get_base_config(
22
+ img_size=224,
23
+ patch_size=16,
24
+ drop_path_rate=0,
25
+ checkpoint_activations=None,
26
+ mlp_ratio=4,
27
+ vocab_size=64010,
28
+ **kwargs
29
+ ):
30
+ return EncoderConfig(
31
+ img_size=img_size,
32
+ patch_size=patch_size,
33
+ vocab_size=vocab_size,
34
+ multiway=True,
35
+ layernorm_embedding=False,
36
+ normalize_output=True,
37
+ no_output_layer=True,
38
+ drop_path_rate=drop_path_rate,
39
+ encoder_embed_dim=768,
40
+ encoder_attention_heads=12,
41
+ encoder_ffn_embed_dim=int(768 * mlp_ratio),
42
+ encoder_layers=12,
43
+ checkpoint_activations=checkpoint_activations,
44
+ )
45
+
46
+
47
+ def _get_large_config(
48
+ img_size=224,
49
+ patch_size=16,
50
+ drop_path_rate=0,
51
+ checkpoint_activations=None,
52
+ mlp_ratio=4,
53
+ vocab_size=64010,
54
+ **kwargs
55
+ ):
56
+ return EncoderConfig(
57
+ img_size=img_size,
58
+ patch_size=patch_size,
59
+ vocab_size=vocab_size,
60
+ multiway=True,
61
+ layernorm_embedding=False,
62
+ normalize_output=True,
63
+ no_output_layer=True,
64
+ drop_path_rate=drop_path_rate,
65
+ encoder_embed_dim=1024,
66
+ encoder_attention_heads=16,
67
+ encoder_ffn_embed_dim=int(1024 * mlp_ratio),
68
+ encoder_layers=24,
69
+ checkpoint_activations=checkpoint_activations,
70
+ )
71
+
72
+
73
+ class BEiT3Wrapper(nn.Module):
74
+ def __init__(self, args, **kwargs):
75
+ super().__init__()
76
+ self.args = args
77
+ self.beit3 = BEiT3(args)
78
+ self.apply(self._init_weights)
79
+
80
+ def fix_init_weight(self):
81
+ def rescale(param, layer_id):
82
+ param.div_(math.sqrt(2.0 * layer_id))
83
+
84
+ for layer_id, layer in enumerate(self.blocks):
85
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
86
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
87
+
88
+ def get_num_layers(self):
89
+ return self.beit3.encoder.num_layers
90
+
91
+ @torch.jit.ignore
92
+ def no_weight_decay(self):
93
+ return {
94
+ 'pos_embed',
95
+ 'cls_token',
96
+ 'beit3.encoder.embed_positions.A.weight',
97
+ 'beit3.vision_embed.cls_token',
98
+ 'logit_scale',
99
+ }
100
+
101
+ def _init_weights(self, m):
102
+ if isinstance(m, nn.Linear):
103
+ trunc_normal_(m.weight, std=0.02)
104
+ if isinstance(m, nn.Linear) and m.bias is not None:
105
+ nn.init.constant_(m.bias, 0)
106
+ elif isinstance(m, nn.LayerNorm):
107
+ nn.init.constant_(m.bias, 0)
108
+ nn.init.constant_(m.weight, 1.0)
src/itr/router.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from .vlm_model import VisionLanguageModel
2
+ from .beit3_model import Beit3Model
3
+ from fastapi import APIRouter, File, status
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
+
7
+ from .dtb_cursor import DatabaseCursor
8
+
9
+
10
+ class Item(BaseModel):
11
+ query_text: str
12
+ topk: int
13
+
14
+
15
+ router = APIRouter()
16
+
17
+ vectordb_cursor = None
18
+ vlm_model = None
19
+
20
+
21
+ def init_vectordb(**kargs):
22
+ # Singleton pattern
23
+ global vectordb_cursor
24
+ if vectordb_cursor is None:
25
+ vectordb_cursor = DatabaseCursor(**kargs)
26
+
27
+
28
+ def init_model(device: str):
29
+ # Singleton
30
+ global vlm_model
31
+ if vlm_model is None:
32
+ vlm_model = Beit3Model(device=device)
33
+
34
+
35
+ @router.post("/retrieval/image-text")
36
+ async def retrieve(item: Item) -> JSONResponse:
37
+ try:
38
+ query_vector = vlm_model.get_embedding(input=item.query_text)
39
+ search_results = vectordb_cursor.kNN_search(query_vector, item.topk)
40
+ except Exception:
41
+ return JSONResponse(
42
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
43
+ content={"message": "Search error"},
44
+ )
45
+
46
+ return JSONResponse(
47
+ status_code=status.HTTP_200_OK,
48
+ content={"message": "success", "details": search_results},
49
+ )
src/itr/utils.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import argparse
9
+ import datetime
10
+ import io
11
+ import json
12
+ import math
13
+ import os
14
+ import time
15
+ from collections import defaultdict, deque
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed as dist
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from timm.utils import get_state_dict
24
+ from torch import inf
25
+
26
+
27
+ def bool_flag(s):
28
+ """
29
+ Parse boolean arguments from the command line.
30
+ """
31
+ FALSY_STRINGS = {"off", "false", "0"}
32
+ TRUTHY_STRINGS = {"on", "true", "1"}
33
+ if s.lower() in FALSY_STRINGS:
34
+ return False
35
+ elif s.lower() in TRUTHY_STRINGS:
36
+ return True
37
+ else:
38
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
39
+
40
+
41
+ class SmoothedValue:
42
+ """Track a series of values and provide access to smoothed values over a
43
+ window or the global series average.
44
+ """
45
+
46
+ def __init__(self, window_size=20, fmt=None):
47
+ if fmt is None:
48
+ fmt = "{median:.4f} ({global_avg:.4f})"
49
+ self.deque = deque(maxlen=window_size)
50
+ self.total = 0.0
51
+ self.count = 0
52
+ self.fmt = fmt
53
+
54
+ def update(self, value, n=1):
55
+ self.deque.append(value)
56
+ self.count += n
57
+ self.total += value * n
58
+
59
+ def synchronize_between_processes(self):
60
+ """
61
+ Warning: does not synchronize the deque!
62
+ """
63
+ if not is_dist_avail_and_initialized():
64
+ return
65
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
66
+ dist.barrier()
67
+ dist.all_reduce(t)
68
+ t = t.tolist()
69
+ self.count = int(t[0])
70
+ self.total = t[1]
71
+
72
+ @property
73
+ def median(self):
74
+ d = torch.tensor(list(self.deque))
75
+ return d.median().item()
76
+
77
+ @property
78
+ def avg(self):
79
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
80
+ return d.mean().item()
81
+
82
+ @property
83
+ def global_avg(self):
84
+ return self.total / self.count
85
+
86
+ @property
87
+ def max(self):
88
+ return max(self.deque)
89
+
90
+ @property
91
+ def value(self):
92
+ return self.deque[-1]
93
+
94
+ def __str__(self):
95
+ return self.fmt.format(
96
+ median=self.median,
97
+ avg=self.avg,
98
+ global_avg=self.global_avg,
99
+ max=self.max,
100
+ value=self.value,
101
+ )
102
+
103
+
104
+ class MetricLogger:
105
+ def __init__(self, delimiter="\t"):
106
+ self.meters = defaultdict(SmoothedValue)
107
+ self.delimiter = delimiter
108
+
109
+ def update(self, **kwargs):
110
+ for k, v in kwargs.items():
111
+ if v is None:
112
+ continue
113
+ if isinstance(v, torch.Tensor):
114
+ v = v.item()
115
+ assert isinstance(v, (float, int))
116
+ self.meters[k].update(v)
117
+
118
+ def __getattr__(self, attr):
119
+ if attr in self.meters:
120
+ return self.meters[attr]
121
+ if attr in self.__dict__:
122
+ return self.__dict__[attr]
123
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
124
+
125
+ def __str__(self):
126
+ loss_str = []
127
+ for name, meter in self.meters.items():
128
+ loss_str.append(f"{name}: {str(meter)}")
129
+ return self.delimiter.join(loss_str)
130
+
131
+ def synchronize_between_processes(self):
132
+ for meter in self.meters.values():
133
+ meter.synchronize_between_processes()
134
+
135
+ def add_meter(self, name, meter):
136
+ self.meters[name] = meter
137
+
138
+ def log_every(self, iterable, print_freq, header=None):
139
+ i = 0
140
+ if not header:
141
+ header = ''
142
+ start_time = time.time()
143
+ end = time.time()
144
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
145
+ data_time = SmoothedValue(fmt='{avg:.4f}')
146
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
147
+ log_msg = [
148
+ header,
149
+ '[{0' + space_fmt + '}/{1}]',
150
+ 'eta: {eta}',
151
+ '{meters}',
152
+ 'time: {time}',
153
+ 'data: {data}',
154
+ ]
155
+ if torch.cuda.is_available():
156
+ log_msg.append('max mem: {memory:.0f}')
157
+ log_msg = self.delimiter.join(log_msg)
158
+ MB = 1024.0 * 1024.0
159
+ for obj in iterable:
160
+ data_time.update(time.time() - end)
161
+ yield obj
162
+ iter_time.update(time.time() - end)
163
+ if i % print_freq == 0 or i == len(iterable) - 1:
164
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
165
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
166
+ if torch.cuda.is_available():
167
+ print(
168
+ log_msg.format(
169
+ i,
170
+ len(iterable),
171
+ eta=eta_string,
172
+ meters=str(self),
173
+ time=str(iter_time),
174
+ data=str(data_time),
175
+ memory=torch.cuda.max_memory_allocated() / MB,
176
+ )
177
+ )
178
+ else:
179
+ print(
180
+ log_msg.format(
181
+ i,
182
+ len(iterable),
183
+ eta=eta_string,
184
+ meters=str(self),
185
+ time=str(iter_time),
186
+ data=str(data_time),
187
+ )
188
+ )
189
+ i += 1
190
+ end = time.time()
191
+ total_time = time.time() - start_time
192
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
193
+ print(
194
+ '{} Total time: {} ({:.4f} s / it)'.format(
195
+ header, total_time_str, total_time / len(iterable)
196
+ )
197
+ )
198
+
199
+
200
+ def _load_checkpoint_for_ema(model_ema, checkpoint):
201
+ """
202
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
203
+ """
204
+ mem_file = io.BytesIO()
205
+ torch.save(checkpoint, mem_file)
206
+ mem_file.seek(0)
207
+ model_ema._load_checkpoint(mem_file)
208
+
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in master process
213
+ """
214
+ import builtins as __builtin__
215
+
216
+ builtin_print = __builtin__.print
217
+
218
+ def print(*args, **kwargs):
219
+ force = kwargs.pop('force', False)
220
+ if is_master or force:
221
+ builtin_print(*args, **kwargs)
222
+
223
+ __builtin__.print = print
224
+
225
+
226
+ def is_dist_avail_and_initialized():
227
+ if not dist.is_available():
228
+ return False
229
+ if not dist.is_initialized():
230
+ return False
231
+ return True
232
+
233
+
234
+ def get_world_size():
235
+ if not is_dist_avail_and_initialized():
236
+ return 1
237
+ return dist.get_world_size()
238
+
239
+
240
+ def get_rank():
241
+ if not is_dist_avail_and_initialized():
242
+ return 0
243
+ return dist.get_rank()
244
+
245
+
246
+ def is_main_process():
247
+ return get_rank() == 0
248
+
249
+
250
+ def save_on_master(*args, **kwargs):
251
+ if is_main_process():
252
+ torch.save(*args, **kwargs)
253
+
254
+
255
+ def _get_rank_env():
256
+ if "RANK" in os.environ:
257
+ return int(os.environ["RANK"])
258
+ else:
259
+ return int(os.environ['OMPI_COMM_WORLD_RANK'])
260
+
261
+
262
+ def _get_local_rank_env():
263
+ if "LOCAL_RANK" in os.environ:
264
+ return int(os.environ["LOCAL_RANK"])
265
+ else:
266
+ return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
267
+
268
+
269
+ def _get_world_size_env():
270
+ if "WORLD_SIZE" in os.environ:
271
+ return int(os.environ["WORLD_SIZE"])
272
+ else:
273
+ return int(os.environ['OMPI_COMM_WORLD_SIZE'])
274
+
275
+
276
+ # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
277
+ def init_distributed_mode(args):
278
+ if args.dist_on_itp:
279
+ args.rank = _get_rank_env()
280
+ args.world_size = _get_world_size_env() # int(os.environ['OMPI_COMM_WORLD_SIZE'])
281
+ args.gpu = _get_local_rank_env()
282
+ args.dist_url = "tcp://{}:{}".format(os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
283
+ os.environ['LOCAL_RANK'] = str(args.gpu)
284
+ os.environ['RANK'] = str(args.rank)
285
+ os.environ['WORLD_SIZE'] = str(args.world_size)
286
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
287
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
288
+ args.rank = int(os.environ["RANK"])
289
+ args.world_size = int(os.environ['WORLD_SIZE'])
290
+ args.gpu = int(os.environ['LOCAL_RANK'])
291
+ elif 'SLURM_PROCID' in os.environ:
292
+ args.rank = int(os.environ['SLURM_PROCID'])
293
+ args.gpu = args.rank % torch.cuda.device_count()
294
+ else:
295
+ print('Not using distributed mode')
296
+ args.distributed = False
297
+ return
298
+
299
+ args.distributed = True
300
+
301
+ torch.cuda.set_device(args.gpu)
302
+ args.dist_backend = 'nccl'
303
+ print(
304
+ f'| distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}',
305
+ flush=True,
306
+ )
307
+ torch.distributed.init_process_group(
308
+ backend=args.dist_backend,
309
+ init_method=args.dist_url,
310
+ world_size=args.world_size,
311
+ rank=args.rank,
312
+ timeout=datetime.timedelta(0, 7200),
313
+ )
314
+ torch.distributed.barrier()
315
+ setup_for_distributed(args.rank == 0)
316
+
317
+
318
+ def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
319
+ missing_keys = []
320
+ unexpected_keys = []
321
+ error_msgs = []
322
+ # copy state_dict so _load_from_state_dict can modify it
323
+ metadata = getattr(state_dict, '_metadata', None)
324
+ state_dict = state_dict.copy()
325
+ if metadata is not None:
326
+ state_dict._metadata = metadata
327
+
328
+ def load(module, prefix=''):
329
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
330
+ module._load_from_state_dict(
331
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
332
+ )
333
+ for name, child in module._modules.items():
334
+ if child is not None:
335
+ load(child, prefix + name + '.')
336
+
337
+ load(model, prefix=prefix)
338
+
339
+ warn_missing_keys = []
340
+ ignore_missing_keys = []
341
+ for key in missing_keys:
342
+ keep_flag = True
343
+ for ignore_key in ignore_missing.split('|'):
344
+ if ignore_key in key:
345
+ keep_flag = False
346
+ break
347
+ if keep_flag:
348
+ warn_missing_keys.append(key)
349
+ else:
350
+ ignore_missing_keys.append(key)
351
+
352
+ missing_keys = warn_missing_keys
353
+
354
+ if len(missing_keys) > 0:
355
+ print(
356
+ "Weights of {} not initialized from pretrained model: {}".format(
357
+ model.__class__.__name__, missing_keys
358
+ )
359
+ )
360
+ if len(unexpected_keys) > 0:
361
+ print(
362
+ "Weights from pretrained model not used in {}: {}".format(
363
+ model.__class__.__name__, unexpected_keys
364
+ )
365
+ )
366
+ if len(ignore_missing_keys) > 0:
367
+ print(
368
+ "Ignored weights of {} not initialized from pretrained model: {}".format(
369
+ model.__class__.__name__, ignore_missing_keys
370
+ )
371
+ )
372
+ if len(error_msgs) > 0:
373
+ print('\n'.join(error_msgs))
374
+
375
+
376
+ class NativeScalerWithGradNormCount:
377
+ state_dict_key = "amp_scaler"
378
+
379
+ def __init__(self):
380
+ self._scaler = torch.cuda.amp.GradScaler()
381
+
382
+ def __call__(
383
+ self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True
384
+ ):
385
+ self._scaler.scale(loss).backward(create_graph=create_graph)
386
+ if update_grad:
387
+ if clip_grad is not None:
388
+ assert parameters is not None
389
+ self._scaler.unscale_(
390
+ optimizer
391
+ ) # unscale the gradients of optimizer's assigned params in-place
392
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
393
+ else:
394
+ self._scaler.unscale_(optimizer)
395
+ norm = get_grad_norm_(parameters)
396
+ self._scaler.step(optimizer)
397
+ self._scaler.update()
398
+ else:
399
+ norm = None
400
+ return norm
401
+
402
+ def state_dict(self):
403
+ return self._scaler.state_dict()
404
+
405
+ def load_state_dict(self, state_dict):
406
+ self._scaler.load_state_dict(state_dict)
407
+
408
+
409
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
410
+ if isinstance(parameters, torch.Tensor):
411
+ parameters = [parameters]
412
+ parameters = [p for p in parameters if p.grad is not None]
413
+ norm_type = float(norm_type)
414
+ if len(parameters) == 0:
415
+ return torch.tensor(0.0)
416
+ device = parameters[0].grad.device
417
+ if norm_type == inf:
418
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
419
+ else:
420
+ total_norm = torch.norm(
421
+ torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
422
+ norm_type,
423
+ )
424
+ return total_norm
425
+
426
+
427
+ def cosine_scheduler(
428
+ base_value,
429
+ final_value,
430
+ epochs,
431
+ niter_per_ep,
432
+ warmup_epochs=0,
433
+ start_warmup_value=0,
434
+ warmup_steps=-1,
435
+ sched_type="cos",
436
+ ):
437
+ warmup_schedule = np.array([])
438
+ warmup_iters = warmup_epochs * niter_per_ep
439
+ if warmup_steps > 0:
440
+ warmup_iters = warmup_steps
441
+ print("Set warmup steps = %d" % warmup_iters)
442
+ if warmup_epochs > 0:
443
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
444
+
445
+ if sched_type == "cos":
446
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
447
+ schedule = np.array(
448
+ [
449
+ final_value
450
+ + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters))))
451
+ for i in iters
452
+ ]
453
+ )
454
+ elif sched_type == "linear":
455
+ schedule = np.linspace(base_value, final_value, epochs * niter_per_ep - warmup_iters)
456
+ else:
457
+ raise NotImplementedError()
458
+
459
+ schedule = np.concatenate((warmup_schedule, schedule))
460
+
461
+ assert len(schedule) == epochs * niter_per_ep
462
+ return schedule
463
+
464
+
465
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
466
+ output_dir = Path(args.output_dir)
467
+ if loss_scaler is not None:
468
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)]
469
+ for checkpoint_path in checkpoint_paths:
470
+ to_save = {
471
+ 'model': model_without_ddp.state_dict(),
472
+ 'optimizer': optimizer.state_dict(),
473
+ 'epoch': epoch,
474
+ 'scaler': loss_scaler.state_dict(),
475
+ 'args': args,
476
+ }
477
+
478
+ if model_ema is not None:
479
+ to_save['model_ema'] = get_state_dict(model_ema)
480
+
481
+ save_on_master(to_save, checkpoint_path)
482
+ else:
483
+ client_state = {'epoch': epoch, "args": args}
484
+ if model_ema is not None:
485
+ client_state['model_ema'] = get_state_dict(model_ema)
486
+ model.save_checkpoint(
487
+ save_dir=args.output_dir, tag="checkpoint-%s" % epoch, client_state=client_state
488
+ )
489
+
490
+
491
+ def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
492
+ output_dir = Path(args.output_dir)
493
+ if loss_scaler is not None:
494
+ # torch.amp
495
+ if args.auto_resume and len(args.resume) == 0:
496
+ import glob
497
+
498
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
499
+ latest_ckpt = -1
500
+ for ckpt in all_checkpoints:
501
+ t = ckpt.split('-')[-1].split('.')[0]
502
+ if t.isdigit():
503
+ latest_ckpt = max(int(t), latest_ckpt)
504
+ if latest_ckpt >= 0:
505
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
506
+ print("Auto resume checkpoint: %s" % args.resume)
507
+
508
+ if args.resume:
509
+ if args.resume.startswith('https'):
510
+ checkpoint = torch.hub.load_state_dict_from_url(
511
+ args.resume, map_location='cpu', check_hash=True
512
+ )
513
+ else:
514
+ checkpoint = torch.load(args.resume, map_location='cpu')
515
+ model_without_ddp.load_state_dict(checkpoint['model'])
516
+ print("Resume checkpoint %s" % args.resume)
517
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
518
+ optimizer.load_state_dict(checkpoint['optimizer'])
519
+ args.start_epoch = checkpoint['epoch'] + 1
520
+ if hasattr(args, 'model_ema') and args.model_ema:
521
+ _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
522
+ if 'scaler' in checkpoint:
523
+ loss_scaler.load_state_dict(checkpoint['scaler'])
524
+ print("With optim & sched!")
525
+ else:
526
+ # deepspeed, only support '--auto_resume'.
527
+ if args.auto_resume:
528
+ import glob
529
+
530
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
531
+ latest_ckpt = -1
532
+ for ckpt in all_checkpoints:
533
+ t = ckpt.split('-')[-1].split('.')[0]
534
+ if t.isdigit():
535
+ latest_ckpt = max(int(t), latest_ckpt)
536
+ if latest_ckpt >= 0:
537
+ args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
538
+ print("Auto resume checkpoint: %d" % latest_ckpt)
539
+ _, client_states = model.load_checkpoint(
540
+ args.output_dir, tag='checkpoint-%d' % latest_ckpt
541
+ )
542
+ args.start_epoch = client_states['epoch'] + 1
543
+ if model_ema is not None:
544
+ if args.model_ema:
545
+ _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
546
+
547
+
548
+ # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
549
+ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
550
+ if ckpt_path.startswith('https'):
551
+ checkpoint = torch.hub.load_state_dict_from_url(
552
+ ckpt_path, map_location='cpu', check_hash=True
553
+ )
554
+ else:
555
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
556
+
557
+ print("Load ckpt from %s" % ckpt_path)
558
+ checkpoint_model = None
559
+ for model_key in model_key.split('|'):
560
+ if model_key in checkpoint:
561
+ checkpoint_model = checkpoint[model_key]
562
+ print("Load state_dict by model_key = %s" % model_key)
563
+ break
564
+
565
+ if checkpoint_model is None:
566
+ checkpoint_model = checkpoint
567
+
568
+ state_dict = model.state_dict()
569
+ for k in ['head.weight', 'head.bias']:
570
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
571
+ print(f"Removing key {k} from pretrained checkpoint")
572
+ del checkpoint_model[k]
573
+
574
+ # interpolate position embedding
575
+ for pos_embed_key in (
576
+ "vision_pos_embed",
577
+ "pos_embed",
578
+ "beit3.encoder.embed_positions.A.weight",
579
+ ):
580
+ if pos_embed_key in checkpoint_model:
581
+ pos_embed_checkpoint = checkpoint_model[pos_embed_key]
582
+ embedding_size = pos_embed_checkpoint.shape[-1]
583
+ if pos_embed_key == "beit3.encoder.embed_positions.A.weight":
584
+ # being consistent with Fairseq, which starts from 2 for position embedding
585
+ torchscale_model = True
586
+ num_patches = model.beit3.vision_embed.num_patches
587
+ num_extra_tokens = (
588
+ model.beit3.vision_embed.num_position_embeddings() + 2 - num_patches
589
+ )
590
+ else:
591
+ torchscale_model = False
592
+ num_patches = model.patch_embed.num_patches
593
+ num_extra_tokens = getattr(model, pos_embed_key).shape[-2] - num_patches
594
+ # height (== width) for the checkpoint position embedding
595
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
596
+ # height (== width) for the new position embedding
597
+ new_size = int(num_patches**0.5)
598
+ # class_token and dist_token are kept unchanged
599
+ if orig_size != new_size:
600
+ print(
601
+ "Position interpolate from %dx%d to %dx%d"
602
+ % (orig_size, orig_size, new_size, new_size)
603
+ )
604
+ if torchscale_model:
605
+ extra_tokens = pos_embed_checkpoint[:num_extra_tokens].unsqueeze(0)
606
+ # only the position tokens are interpolated
607
+ pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
608
+ else:
609
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
610
+ # only the position tokens are interpolated
611
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
612
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(
613
+ 0, 3, 1, 2
614
+ )
615
+ pos_tokens = torch.nn.functional.interpolate(
616
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
617
+ )
618
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
619
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
620
+ if torchscale_model:
621
+ new_pos_embed = new_pos_embed.squeeze(0)
622
+ checkpoint_model[pos_embed_key] = new_pos_embed
623
+
624
+ load_state_dict(model, checkpoint_model, prefix=model_prefix)
625
+
626
+
627
+ def create_ds_config(args):
628
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
629
+ with open(args.deepspeed_config, mode="w") as writer:
630
+ ds_config = {
631
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
632
+ "train_micro_batch_size_per_gpu": args.batch_size,
633
+ "steps_per_print": 1000,
634
+ "optimizer": {
635
+ "type": "Adam",
636
+ "adam_w_mode": True,
637
+ "params": {
638
+ "lr": args.lr,
639
+ "weight_decay": args.weight_decay,
640
+ "bias_correction": True,
641
+ "betas": [args.opt_betas[0], args.opt_betas[1]],
642
+ "eps": args.opt_eps,
643
+ },
644
+ },
645
+ "fp16": {
646
+ "enabled": True,
647
+ "loss_scale": 0,
648
+ "initial_scale_power": getattr(args, "initial_scale_power", 12),
649
+ "loss_scale_window": 1000,
650
+ "hysteresis": 2,
651
+ "min_loss_scale": 1,
652
+ },
653
+ "amp": {"enabled": False, "opt_level": "O2"},
654
+ }
655
+
656
+ if args.clip_grad is not None:
657
+ ds_config.update({'gradient_clipping': args.clip_grad})
658
+
659
+ if args.zero_stage == 1:
660
+ ds_config.update(
661
+ {"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}}
662
+ )
663
+ elif args.zero_stage > 1:
664
+ raise NotImplementedError()
665
+
666
+ writer.write(json.dumps(ds_config, indent=2))
667
+
668
+
669
+ def merge_batch_tensors_by_dict_key(batch):
670
+ batch_tensors = {}
671
+ for tensor_key in batch[0]:
672
+ if isinstance(batch[0][tensor_key], torch.Tensor):
673
+ batch_tensors[tensor_key] = torch.stack([d[tensor_key] for d in batch])
674
+ else:
675
+ batch_tensors[tensor_key] = torch.tensor(
676
+ [d[tensor_key] for d in batch], dtype=torch.long
677
+ )
678
+ return batch_tensors
679
+
680
+
681
+ def get_loss_scale_for_deepspeed(model):
682
+ optimizer = model.optimizer
683
+ loss_scale = None
684
+ if hasattr(optimizer, 'loss_scale'):
685
+ loss_scale = optimizer.loss_scale
686
+ elif hasattr(optimizer, 'cur_scale'):
687
+ loss_scale = optimizer.cur_scale
688
+ return loss_scale
689
+
690
+
691
+ class GatherLayer(torch.autograd.Function):
692
+ """
693
+ Gather tensors from all workers with support for backward propagation:
694
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
695
+ """
696
+
697
+ @staticmethod
698
+ def forward(ctx, x):
699
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
700
+ dist.all_gather(output, x)
701
+ return tuple(output)
702
+
703
+ @staticmethod
704
+ def backward(ctx, *grads):
705
+ all_gradients = torch.stack(grads)
706
+ dist.all_reduce(all_gradients)
707
+ return all_gradients[dist.get_rank()]
708
+
709
+
710
+ def gather_features(
711
+ image_features,
712
+ text_features,
713
+ ):
714
+ gathered_image_features = GatherLayer.apply(image_features)
715
+ gathered_text_features = GatherLayer.apply(text_features)
716
+ all_image_features = torch.cat(gathered_image_features)
717
+ all_text_features = torch.cat(gathered_text_features)
718
+
719
+ return all_image_features, all_text_features
720
+
721
+
722
+ # The implementation code is modified from open_clip (https://github.com/mlfoundations/open_clip.git)
723
+ class ClipLoss(nn.Module):
724
+ def __init__(
725
+ self,
726
+ cache_labels=False,
727
+ rank=0,
728
+ world_size=1,
729
+ ):
730
+ super().__init__()
731
+ self.cache_labels = cache_labels
732
+ self.rank = rank
733
+ self.world_size = world_size
734
+
735
+ # cache state
736
+ self.prev_num_logits = 0
737
+ self.labels = {}
738
+
739
+ def forward(self, image_features, text_features, logit_scale):
740
+ device = image_features.device
741
+ if self.world_size > 1:
742
+ all_image_features, all_text_features = gather_features(image_features, text_features)
743
+
744
+ logits_per_image = logit_scale * image_features @ all_text_features.T
745
+ logits_per_text = logit_scale * text_features @ all_image_features.T
746
+ else:
747
+ logits_per_image = logit_scale * image_features @ text_features.T
748
+ logits_per_text = logit_scale * text_features @ image_features.T
749
+
750
+ # calculated ground-truth and cache if enabled
751
+ num_logits = logits_per_image.shape[0]
752
+ if self.prev_num_logits != num_logits or device not in self.labels:
753
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
754
+ if self.world_size > 1:
755
+ labels = labels + num_logits * self.rank
756
+ if self.cache_labels:
757
+ self.labels[device] = labels
758
+ self.prev_num_logits = num_logits
759
+ else:
760
+ labels = self.labels[device]
761
+
762
+ total_loss = (
763
+ F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)
764
+ ) / 2
765
+ return total_loss, logits_per_image, logits_per_text
766
+
767
+
768
+ def write_result_to_jsonl(test_stats, result_file):
769
+ with open(result_file, mode="w", encoding="utf-8") as writer:
770
+ writer.write(json.dumps(test_stats, indent=None))
771
+
772
+
773
+ def read_result_from_jsonl(result_file):
774
+ with open(result_file, encoding="utf-8") as reader:
775
+ return json.load(reader)
776
+
777
+
778
+ class BertCaptioningLoss(nn.Module):
779
+ def __init__(self, label_smoothing, drop_worst_ratio, drop_worst_after):
780
+ super().__init__()
781
+ self.label_smoothing = label_smoothing
782
+ self.drop_worst_ratio = drop_worst_ratio
783
+ self.drop_worst_after = drop_worst_after
784
+ self.log_soft = nn.LogSoftmax(dim=1)
785
+ self.kl = nn.KLDivLoss(reduction='none')
786
+ self.iter = 0
787
+
788
+ def forward(self, logits, target, iter):
789
+ eps = self.label_smoothing
790
+ n_class = logits.size(1)
791
+ one_hot = torch.zeros_like(logits).scatter(1, target.view(-1, 1), 1)
792
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
793
+ log_prb = self.log_soft(logits)
794
+ loss = self.kl(log_prb, one_hot).sum(1)
795
+
796
+ if self.drop_worst_ratio > 0 and iter > self.drop_worst_after:
797
+ loss, _ = torch.topk(
798
+ loss, k=int(loss.shape[0] * (1 - self.drop_worst_ratio)), largest=False
799
+ )
800
+ loss = loss.mean()
801
+
802
+ return loss
803
+
804
+
805
+ class BeamHypotheses:
806
+ def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
807
+ """
808
+ Initialize n-best list of hypotheses.
809
+ """
810
+ self.max_length = max_length - 1 # ignoring bos_token
811
+ self.length_penalty = length_penalty
812
+ self.early_stopping = early_stopping
813
+ self.n_hyp = n_hyp
814
+ self.hyp = []
815
+ self.worst_score = 1e9
816
+
817
+ def __len__(self):
818
+ """
819
+ Number of hypotheses in the list.
820
+ """
821
+ return len(self.hyp)
822
+
823
+ def add(self, hyp, sum_logprobs):
824
+ """
825
+ Add a new hypothesis to the list.
826
+ """
827
+ score = sum_logprobs / len(hyp) ** self.length_penalty
828
+ if len(self) < self.n_hyp or score > self.worst_score:
829
+ self.hyp.append((score, hyp))
830
+ if len(self) > self.n_hyp:
831
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
832
+ del self.hyp[sorted_scores[0][1]]
833
+ self.worst_score = sorted_scores[1][0]
834
+ else:
835
+ self.worst_score = min(score, self.worst_score)
836
+
837
+ def is_done(self, best_sum_logprobs):
838
+ """
839
+ If there are enough hypotheses and that none of the hypotheses being generated
840
+ can become better than the worst one in the heap, then we are done with this sentence.
841
+ """
842
+ if len(self) < self.n_hyp:
843
+ return False
844
+ elif self.early_stopping:
845
+ return True
846
+ else:
847
+ return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty
848
+
849
+
850
+ def dump_predictions(args, result, file_suffix):
851
+ global_rank = get_rank()
852
+ jsons = None
853
+ if global_rank >= 0:
854
+ output_file = os.path.join(args.task_cache_path, f"submit_{global_rank}_{file_suffix}.json")
855
+ with open(output_file, "w") as fp:
856
+ json.dump(result, fp, indent=2)
857
+ torch.distributed.barrier()
858
+
859
+ if global_rank == 0:
860
+ world_size = get_world_size()
861
+ jsons = []
862
+ for i in range(world_size):
863
+ each_file = os.path.join(args.task_cache_path, f"submit_{i}_{file_suffix}.json")
864
+ with open(each_file) as fp:
865
+ jsons += json.load(fp)
866
+
867
+ new_jsons = []
868
+ res_dict = dict()
869
+ if args.task in ["coco_captioning", "nocaps"]:
870
+ qid_key = "image_id"
871
+ else:
872
+ # for VQAv2
873
+ qid_key = "question_id"
874
+ for item in jsons:
875
+ if item[qid_key] in res_dict:
876
+ continue
877
+ new_jsons.append(item)
878
+ res_dict[item[qid_key]] = item
879
+ jsons = new_jsons
880
+
881
+ torch.distributed.barrier()
882
+ os.remove(output_file)
883
+ else:
884
+ jsons = result
885
+
886
+ result_file = os.path.join(args.output_dir, f"submit_{file_suffix}.json")
887
+ if jsons is not None:
888
+ with open(result_file, "w") as fp:
889
+ json.dump(jsons, fp, indent=2)
890
+ print("Infer %d examples into %s" % (len(jsons), result_file))
891
+ return result_file
src/itr/vlm_model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Union
3
+
4
+ import clip
5
+ from PIL import Image
6
+
7
+
8
+ class VisionLanguageModel:
9
+ def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
10
+ self._load_model(model_name, device)
11
+ self.device = device
12
+
13
+ @lru_cache(maxsize=1)
14
+ def _load_model(self, model_name, device: str = "cpu"):
15
+ self.model, self.processor = clip.load(model_name, device=device)
16
+
17
+ def get_embedding(self, input: Union[str, Image.Image]):
18
+ if isinstance(input, str):
19
+ tokens = clip.tokenize(input).to(self.device)
20
+ vector = self.model.encode_text(tokens)
21
+ vector /= vector.norm(dim=-1, keepdim=True)
22
+ vector = vector.cpu().detach().numpy().astype("float32")
23
+ return vector
24
+ elif isinstance(input, Image.Image):
25
+ image_input = self.preprocess(input).unsqueeze(0).to(self.device)
26
+ vector = self.model.encode_image(image_input)
27
+ vector /= vector.norm(dim=-1, keepdim=True)
28
+ vector = vector.cpu().detach().numpy().astype("float32")
29
+ return vector
30
+ else:
31
+ raise Exception("Invalid input type")
src/main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from config import settings
4
+ from fastapi import FastAPI, Request, status
5
+ from fastapi.exceptions import RequestValidationError
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse, RedirectResponse
8
+ from itr.router import init_model, init_vectordb
9
+ from itr.router import router as router
10
+ from pathlib import Path
11
+
12
+ app = FastAPI(title="[BeiT-3] Text-to-image Retrieval API")
13
+
14
+ SERVICE_ROOT = Path(__file__).parent.parent
15
+
16
+
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=settings.CORS_ORIGINS,
20
+ allow_headers=settings.CORS_HEADERS,
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ )
24
+
25
+
26
+ @app.exception_handler(RequestValidationError)
27
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
28
+ # Get the original 'detail' list of errors
29
+ details = exc.errors()
30
+ error_details = []
31
+
32
+ for error in details:
33
+ error_details.append({"error": f"{error['msg']} {str(error['loc'])}"})
34
+ return JSONResponse(content={"message": error_details})
35
+
36
+
37
+ @app.on_event("startup")
38
+ async def startup_event():
39
+ init_vectordb(
40
+ index_file_path=os.path.join(SERVICE_ROOT, settings.INDEX_FILE_PATH),
41
+ keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
42
+ )
43
+ device = (
44
+ "cuda" if settings.DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
45
+ )
46
+ init_model(device=device)
47
+
48
+
49
+ @app.get("/", include_in_schema=False)
50
+ async def root() -> None:
51
+ return RedirectResponse("/docs")
52
+
53
+
54
+ @app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
55
+ async def perform_healthcheck() -> None:
56
+ return JSONResponse(content={"message": "success"})
57
+
58
+
59
+ app.include_router(router)
60
+
61
+
62
+ # Start API
63
+ if __name__ == "__main__":
64
+ print(os.listdir(os.path.join(SERVICE_ROOT, "data/faiss-index/")))
65
+ import uvicorn
66
+
67
+ uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)