ngxquang commited on
Commit
52b1203
·
1 Parent(s): afe0d05

feat: add subframes data for clip api

Browse files
.env ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PROJECT INFORMATION
2
+ HOST=0.0.0.0
3
+ PORT=7860
4
+ CORS_HEADERS=["*"]
5
+ CORS_ORIGINS=["*"]
6
+
7
+ MODEL_NAME="ViT-B/32"
8
+ DEVICE="cpu" # ["cuda", "cpu"]
9
+
10
+ INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L36.faiss"
11
+ INDEX_SUBFRAMES_FILE_PATH="data/faiss-index/index_clip_subframes_L01_to_L36.faiss"
12
+ KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L36.json"
13
+ SUBFRAMES_GROUPS_JSON_PATH="data/config/subframes_groups_L01_to_L36.json"
.env.example ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PROJECT INFORMATION
2
+ HOST=0.0.0.0
3
+ PORT=7860
4
+ CORS_HEADERS=["*"]
5
+ CORS_ORIGINS=["*"]
6
+
7
+ MODEL_NAME="ViT-B/32"
8
+ DEVICE="cpu" # ["cuda", "cpu"]
9
+
10
+ INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L36.faiss"
11
+ INDEX_SUBFRAMES_FILE_PATH="data/faiss-index/index_clip_subframes_L01_to_L36.faiss"
12
+ KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L36.json"
13
+ SUBFRAMES_GROUPS_JSON_PATH="data/config/subframes_groups_L01_to_L36.json"
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
37
+ *.faiss filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .venv
124
+ env/
125
+ venv/
126
+ ENV/
127
+ env.bak/
128
+ venv.bak/
129
+
130
+ # Spyder project settings
131
+ .spyderproject
132
+ .spyproject
133
+
134
+ # Rope project settings
135
+ .ropeproject
136
+
137
+ # mkdocs documentation
138
+ /site
139
+
140
+ # mypy
141
+ .mypy_cache/
142
+ .dmypy.json
143
+ dmypy.json
144
+
145
+ # Pyre type checker
146
+ .pyre/
147
+
148
+ # pytype static type analyzer
149
+ .pytype/
150
+
151
+ # Cython debug symbols
152
+ cython_debug/
153
+
154
+ # Model Checkpoitns
155
+ *.pth
156
+
157
+ #Sentencepiece Tokenizer
158
+ *.spm
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
166
+
167
+ *.zip
168
+ *.xlsx
169
+ /convert/submission
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8-slim
2
+
3
+ RUN apt-get update && \
4
+ apt-get install git gsutil -y && \
5
+ apt clean && \
6
+ rm -rf /var/cache/apt/*
7
+
8
+ WORKDIR /code
9
+
10
+ COPY requirements.txt /code/requirements.txt
11
+
12
+ # PYTHONDONTWRITEBYTECODE=1: Disables the creation of .pyc files (compiled bytecode)
13
+ # PYTHONUNBUFFERED=1: Disables buffering of the standard output stream
14
+ # PYTHONIOENCODING: specifies the encoding to be used for the standard input, output, and error streams
15
+ ENV PYTHONDONTWRITEBYTECODE=1 \
16
+ PYTHONUNBUFFERED=1 \
17
+ PYTHONIOENCODING=utf-8
18
+
19
+ RUN pip install -U pip && \
20
+ pip install --no-cache-dir -r /code/requirements.txt
21
+
22
+ RUN useradd -m -u 1000 user
23
+
24
+ USER user
25
+
26
+ ENV HOME=/home/user \
27
+ PATH=/home/user/.local/bin:$PATH
28
+
29
+ WORKDIR $HOME/app
30
+
31
+ COPY --chown=user . $HOME/app
32
+
33
+ CMD ["python", "./src/main.py"]
data/config/keyframes_groups_L01_to_L36.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a28d33542216ad24cb09db5f4fd1040c0c2045bcd42d8a4f5e1d038deac73db4
3
+ size 29038197
data/config/subframes_groups_L01_to_L36.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:980472aaad434482a2e89d5a8bc076a923b41c26437b597ceb6c7de34bc4f9c7
3
+ size 28967171
data/faiss-index/index_clip_L01_to_L36.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:614c04492f8bb40dab35d5317c1ee52b5a2fee78e92b2cc5bf71386817f63172
3
+ size 674996269
data/faiss-index/index_clip_subframes_L01_to_L36.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4542bbf7f47179b00b0a4dc7f577245490c86b0399235c39051b12eaafc2efa
3
+ size 671422509
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.103.1
2
+ uvicorn==0.23.2
3
+ pydantic-settings==2.0.3
4
+
5
+
6
+ # Models
7
+ torch==1.7.1
8
+ torchvision==0.8.2
9
+ ftfy==6.1.1
10
+ regex
11
+ tqdm==4.66.1
12
+ git+https://github.com/openai/CLIP.git@main
13
+
14
+ # Vector Database
15
+ faiss-cpu
16
+
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from pydantic_settings import BaseSettings
4
+
5
+ FILE = Path(__file__)
6
+ ROOT = FILE.parent.parent
7
+
8
+
9
+ class Settings(BaseSettings):
10
+ # API SETTINGS
11
+ HOST: str
12
+ PORT: int
13
+ CORS_ORIGINS: list
14
+ CORS_HEADERS: list
15
+
16
+ # MODEL SETTINGS
17
+ MODEL_NAME: str = "ViT-B/32"
18
+ DEVICE: str = "cpu"
19
+
20
+ # FAISS DATABASE SETTINGS
21
+ INDEX_FILE_PATH: str
22
+ INDEX_SUBFRAMES_FILE_PATH: str
23
+ KEYFRAMES_GROUPS_JSON_PATH: str
24
+ SUBFRAMES_GROUPS_JSON_PATH: str
25
+
26
+ class Config:
27
+ env_file = ROOT / ".env"
28
+
29
+
30
+ settings = Settings()
src/itr/__init__.py ADDED
File without changes
src/itr/dtb_cursor.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import faiss
3
+ import os
4
+
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+
8
+
9
+ class DatabaseCursor:
10
+ def __init__(
11
+ self,
12
+ index_file_path: str,
13
+ index_subframes_file_path: str,
14
+ keyframes_groups_json_path: str,
15
+ subframes_groups_json_path: str,
16
+ ):
17
+ self._load_index(index_file_path, index_subframes_file_path)
18
+ self._load_keyframes_groups_info(
19
+ keyframes_groups_json_path, subframes_groups_json_path
20
+ )
21
+
22
+ @lru_cache(maxsize=1)
23
+ def _load_index(self, index_file_path, index_subframes_file_path):
24
+ self.index = faiss.read_index(index_file_path)
25
+ index_subframes = faiss.read_index(index_subframes_file_path)
26
+ try:
27
+ self.index.merge_from(index_subframes)
28
+ except:
29
+ raise Exception("dtb_cursor::cannot merge keyframes and subframes index")
30
+
31
+ @lru_cache(maxsize=1)
32
+ def _load_keyframes_groups_info(
33
+ self, keyframes_groups_json_path: str, subframes_groups_json_path: str
34
+ ):
35
+ with open(keyframes_groups_json_path) as file:
36
+ keyframes_group_info = json.loads(file.read())
37
+ self.no_keyframes = len(keyframes_group_info)
38
+ with open(subframes_groups_json_path) as file:
39
+ subframes_groups_info = json.loads(file.read())
40
+ self.no_subframes = len(subframes_groups_info)
41
+
42
+ self.frames_groups_info = keyframes_group_info
43
+ self.frames_groups_info.extend(subframes_groups_info)
44
+ print(self.index.ntotal)
45
+ assert self.index.ntotal == len(
46
+ self.frames_groups_info
47
+ ), "dtb_cursor::Index length and map lenght mismatch"
48
+
49
+ def kNN_search(self, query_vector: str, topk: int = 10):
50
+ results = []
51
+ distances, ids = self.index.search(query_vector, topk)
52
+ for i in range(len(ids[0])):
53
+ frame_detail = self.frames_groups_info[ids[0][i]]
54
+ frame_detail["distance"] = str(distances[0][i])
55
+ frame_detail["folder"] = (
56
+ "Keyframes" if ids[0][i] < self.no_keyframes else "Subframes"
57
+ )
58
+ results.append(frame_detail)
59
+ return results
src/itr/router.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, File, status
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+
5
+ from .dtb_cursor import DatabaseCursor
6
+ from .vlm_model import VisionLanguageModel
7
+
8
+
9
+ class Item(BaseModel):
10
+ query_text: str
11
+ topk: int
12
+
13
+
14
+ router = APIRouter()
15
+
16
+
17
+ vectordb_cursor = None
18
+ vlm_model = None
19
+
20
+
21
+ def init_vectordb(**kargs):
22
+ # Singleton pattern
23
+ global vectordb_cursor
24
+ if vectordb_cursor is None:
25
+ vectordb_cursor = DatabaseCursor(**kargs)
26
+
27
+
28
+ def init_model(**kargs):
29
+ # Singleton
30
+ global vlm_model
31
+ if vlm_model is None:
32
+ vlm_model = VisionLanguageModel(**kargs)
33
+
34
+
35
+ @router.post("/retrieval")
36
+ async def retrieve(item: Item) -> JSONResponse:
37
+ try:
38
+ query_vector = vlm_model.get_embedding(input=item.query_text)
39
+ search_results = vectordb_cursor.kNN_search(query_vector, item.topk)
40
+ except Exception:
41
+ return JSONResponse(
42
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
43
+ content={"message": "Search error"},
44
+ )
45
+
46
+ return JSONResponse(
47
+ status_code=status.HTTP_200_OK,
48
+ content={"message": "success", "details": search_results},
49
+ )
src/itr/vlm_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Union
3
+
4
+ import clip
5
+ from PIL import Image
6
+
7
+
8
+ class VisionLanguageModel:
9
+ def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
10
+ self._load_model(model_name, device)
11
+ self.device = device
12
+
13
+ @lru_cache(maxsize=1)
14
+ def _load_model(self, model_name, device: str = "cpu"):
15
+ self.model, self.processor = clip.load(model_name, device=device)
16
+
17
+ def get_embedding(self, input: Union[str, Image.Image]):
18
+ if isinstance(input, str):
19
+ tokens = clip.tokenize(input).to(self.device)
20
+ vector = self.model.encode_text(tokens)
21
+ vector /= vector.norm(dim=-1, keepdim=True)
22
+ vector = vector.cpu().detach().numpy().astype("float32")
23
+ return vector
24
+ elif isinstance(input, Image.Image):
25
+ image_input = self.preprocess(input).unsqueeze(0).to(self.device)
26
+ vector = self.model.encode_image(image_input)
27
+ vector /= vector.norm(dim=-1, keepdim=True)
28
+ return vector
29
+ else:
30
+ raise Exception("Invalid input type")
src/main.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from config import settings
3
+ from fastapi import FastAPI, Request, status
4
+ from fastapi.exceptions import RequestValidationError
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.responses import JSONResponse, RedirectResponse
7
+ from itr.router import init_model, init_vectordb
8
+ from itr.router import router as router
9
+
10
+ app = FastAPI(title="Text-to-image Retrieval API")
11
+
12
+
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=settings.CORS_ORIGINS,
16
+ allow_headers=settings.CORS_HEADERS,
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ )
20
+
21
+
22
+ @app.exception_handler(RequestValidationError)
23
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
24
+ # Get the original 'detail' list of errors
25
+ details = exc.errors()
26
+ error_details = []
27
+
28
+ for error in details:
29
+ error_details.append({"error": f"{error['msg']} {str(error['loc'])}"})
30
+ return JSONResponse(content={"message": error_details})
31
+
32
+
33
+ @app.on_event("startup")
34
+ async def startup_event():
35
+ init_vectordb(
36
+ index_file_path=settings.INDEX_FILE_PATH,
37
+ index_subframes_file_path=settings.INDEX_SUBFRAMES_FILE_PATH,
38
+ keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
39
+ subframes_groups_json_path=settings.SUBFRAMES_GROUPS_JSON_PATH,
40
+ )
41
+ device = (
42
+ "cuda" if settings.DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
43
+ )
44
+ init_model(model_name=settings.MODEL_NAME, device=device)
45
+
46
+
47
+ @app.get("/", include_in_schema=False)
48
+ async def root() -> None:
49
+ return RedirectResponse("/docs")
50
+
51
+
52
+ @app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
53
+ async def perform_healthcheck() -> None:
54
+ return JSONResponse(content={"message": "success"})
55
+
56
+
57
+ app.include_router(router)
58
+
59
+
60
+ # Start API
61
+ if __name__ == "__main__":
62
+ import uvicorn
63
+
64
+ uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)