Spaces:
Sleeping
Sleeping
ngxquang
commited on
Commit
·
52b1203
1
Parent(s):
afe0d05
feat: add subframes data for clip api
Browse files- .env +13 -0
- .env.example +13 -0
- .gitattributes +2 -0
- .gitignore +169 -0
- Dockerfile +33 -0
- data/config/keyframes_groups_L01_to_L36.json +3 -0
- data/config/subframes_groups_L01_to_L36.json +3 -0
- data/faiss-index/index_clip_L01_to_L36.faiss +3 -0
- data/faiss-index/index_clip_subframes_L01_to_L36.faiss +3 -0
- requirements.txt +16 -0
- src/__init__.py +0 -0
- src/config.py +30 -0
- src/itr/__init__.py +0 -0
- src/itr/dtb_cursor.py +59 -0
- src/itr/router.py +49 -0
- src/itr/vlm_model.py +30 -0
- src/main.py +64 -0
.env
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PROJECT INFORMATION
|
2 |
+
HOST=0.0.0.0
|
3 |
+
PORT=7860
|
4 |
+
CORS_HEADERS=["*"]
|
5 |
+
CORS_ORIGINS=["*"]
|
6 |
+
|
7 |
+
MODEL_NAME="ViT-B/32"
|
8 |
+
DEVICE="cpu" # ["cuda", "cpu"]
|
9 |
+
|
10 |
+
INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L36.faiss"
|
11 |
+
INDEX_SUBFRAMES_FILE_PATH="data/faiss-index/index_clip_subframes_L01_to_L36.faiss"
|
12 |
+
KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L36.json"
|
13 |
+
SUBFRAMES_GROUPS_JSON_PATH="data/config/subframes_groups_L01_to_L36.json"
|
.env.example
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PROJECT INFORMATION
|
2 |
+
HOST=0.0.0.0
|
3 |
+
PORT=7860
|
4 |
+
CORS_HEADERS=["*"]
|
5 |
+
CORS_ORIGINS=["*"]
|
6 |
+
|
7 |
+
MODEL_NAME="ViT-B/32"
|
8 |
+
DEVICE="cpu" # ["cuda", "cpu"]
|
9 |
+
|
10 |
+
INDEX_FILE_PATH="data/faiss-index/index_clip_L01_to_L36.faiss"
|
11 |
+
INDEX_SUBFRAMES_FILE_PATH="data/faiss-index/index_clip_subframes_L01_to_L36.faiss"
|
12 |
+
KEYFRAMES_GROUPS_JSON_PATH="data/config/keyframes_groups_L01_to_L36.json"
|
13 |
+
SUBFRAMES_GROUPS_JSON_PATH="data/config/subframes_groups_L01_to_L36.json"
|
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.venv
|
124 |
+
env/
|
125 |
+
venv/
|
126 |
+
ENV/
|
127 |
+
env.bak/
|
128 |
+
venv.bak/
|
129 |
+
|
130 |
+
# Spyder project settings
|
131 |
+
.spyderproject
|
132 |
+
.spyproject
|
133 |
+
|
134 |
+
# Rope project settings
|
135 |
+
.ropeproject
|
136 |
+
|
137 |
+
# mkdocs documentation
|
138 |
+
/site
|
139 |
+
|
140 |
+
# mypy
|
141 |
+
.mypy_cache/
|
142 |
+
.dmypy.json
|
143 |
+
dmypy.json
|
144 |
+
|
145 |
+
# Pyre type checker
|
146 |
+
.pyre/
|
147 |
+
|
148 |
+
# pytype static type analyzer
|
149 |
+
.pytype/
|
150 |
+
|
151 |
+
# Cython debug symbols
|
152 |
+
cython_debug/
|
153 |
+
|
154 |
+
# Model Checkpoitns
|
155 |
+
*.pth
|
156 |
+
|
157 |
+
#Sentencepiece Tokenizer
|
158 |
+
*.spm
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
166 |
+
|
167 |
+
*.zip
|
168 |
+
*.xlsx
|
169 |
+
/convert/submission
|
Dockerfile
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8-slim
|
2 |
+
|
3 |
+
RUN apt-get update && \
|
4 |
+
apt-get install git gsutil -y && \
|
5 |
+
apt clean && \
|
6 |
+
rm -rf /var/cache/apt/*
|
7 |
+
|
8 |
+
WORKDIR /code
|
9 |
+
|
10 |
+
COPY requirements.txt /code/requirements.txt
|
11 |
+
|
12 |
+
# PYTHONDONTWRITEBYTECODE=1: Disables the creation of .pyc files (compiled bytecode)
|
13 |
+
# PYTHONUNBUFFERED=1: Disables buffering of the standard output stream
|
14 |
+
# PYTHONIOENCODING: specifies the encoding to be used for the standard input, output, and error streams
|
15 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
16 |
+
PYTHONUNBUFFERED=1 \
|
17 |
+
PYTHONIOENCODING=utf-8
|
18 |
+
|
19 |
+
RUN pip install -U pip && \
|
20 |
+
pip install --no-cache-dir -r /code/requirements.txt
|
21 |
+
|
22 |
+
RUN useradd -m -u 1000 user
|
23 |
+
|
24 |
+
USER user
|
25 |
+
|
26 |
+
ENV HOME=/home/user \
|
27 |
+
PATH=/home/user/.local/bin:$PATH
|
28 |
+
|
29 |
+
WORKDIR $HOME/app
|
30 |
+
|
31 |
+
COPY --chown=user . $HOME/app
|
32 |
+
|
33 |
+
CMD ["python", "./src/main.py"]
|
data/config/keyframes_groups_L01_to_L36.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a28d33542216ad24cb09db5f4fd1040c0c2045bcd42d8a4f5e1d038deac73db4
|
3 |
+
size 29038197
|
data/config/subframes_groups_L01_to_L36.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:980472aaad434482a2e89d5a8bc076a923b41c26437b597ceb6c7de34bc4f9c7
|
3 |
+
size 28967171
|
data/faiss-index/index_clip_L01_to_L36.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:614c04492f8bb40dab35d5317c1ee52b5a2fee78e92b2cc5bf71386817f63172
|
3 |
+
size 674996269
|
data/faiss-index/index_clip_subframes_L01_to_L36.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e4542bbf7f47179b00b0a4dc7f577245490c86b0399235c39051b12eaafc2efa
|
3 |
+
size 671422509
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi==0.103.1
|
2 |
+
uvicorn==0.23.2
|
3 |
+
pydantic-settings==2.0.3
|
4 |
+
|
5 |
+
|
6 |
+
# Models
|
7 |
+
torch==1.7.1
|
8 |
+
torchvision==0.8.2
|
9 |
+
ftfy==6.1.1
|
10 |
+
regex
|
11 |
+
tqdm==4.66.1
|
12 |
+
git+https://github.com/openai/CLIP.git@main
|
13 |
+
|
14 |
+
# Vector Database
|
15 |
+
faiss-cpu
|
16 |
+
|
src/__init__.py
ADDED
File without changes
|
src/config.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from pydantic_settings import BaseSettings
|
4 |
+
|
5 |
+
FILE = Path(__file__)
|
6 |
+
ROOT = FILE.parent.parent
|
7 |
+
|
8 |
+
|
9 |
+
class Settings(BaseSettings):
|
10 |
+
# API SETTINGS
|
11 |
+
HOST: str
|
12 |
+
PORT: int
|
13 |
+
CORS_ORIGINS: list
|
14 |
+
CORS_HEADERS: list
|
15 |
+
|
16 |
+
# MODEL SETTINGS
|
17 |
+
MODEL_NAME: str = "ViT-B/32"
|
18 |
+
DEVICE: str = "cpu"
|
19 |
+
|
20 |
+
# FAISS DATABASE SETTINGS
|
21 |
+
INDEX_FILE_PATH: str
|
22 |
+
INDEX_SUBFRAMES_FILE_PATH: str
|
23 |
+
KEYFRAMES_GROUPS_JSON_PATH: str
|
24 |
+
SUBFRAMES_GROUPS_JSON_PATH: str
|
25 |
+
|
26 |
+
class Config:
|
27 |
+
env_file = ROOT / ".env"
|
28 |
+
|
29 |
+
|
30 |
+
settings = Settings()
|
src/itr/__init__.py
ADDED
File without changes
|
src/itr/dtb_cursor.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import faiss
|
3 |
+
import os
|
4 |
+
|
5 |
+
from functools import lru_cache
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
|
9 |
+
class DatabaseCursor:
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
index_file_path: str,
|
13 |
+
index_subframes_file_path: str,
|
14 |
+
keyframes_groups_json_path: str,
|
15 |
+
subframes_groups_json_path: str,
|
16 |
+
):
|
17 |
+
self._load_index(index_file_path, index_subframes_file_path)
|
18 |
+
self._load_keyframes_groups_info(
|
19 |
+
keyframes_groups_json_path, subframes_groups_json_path
|
20 |
+
)
|
21 |
+
|
22 |
+
@lru_cache(maxsize=1)
|
23 |
+
def _load_index(self, index_file_path, index_subframes_file_path):
|
24 |
+
self.index = faiss.read_index(index_file_path)
|
25 |
+
index_subframes = faiss.read_index(index_subframes_file_path)
|
26 |
+
try:
|
27 |
+
self.index.merge_from(index_subframes)
|
28 |
+
except:
|
29 |
+
raise Exception("dtb_cursor::cannot merge keyframes and subframes index")
|
30 |
+
|
31 |
+
@lru_cache(maxsize=1)
|
32 |
+
def _load_keyframes_groups_info(
|
33 |
+
self, keyframes_groups_json_path: str, subframes_groups_json_path: str
|
34 |
+
):
|
35 |
+
with open(keyframes_groups_json_path) as file:
|
36 |
+
keyframes_group_info = json.loads(file.read())
|
37 |
+
self.no_keyframes = len(keyframes_group_info)
|
38 |
+
with open(subframes_groups_json_path) as file:
|
39 |
+
subframes_groups_info = json.loads(file.read())
|
40 |
+
self.no_subframes = len(subframes_groups_info)
|
41 |
+
|
42 |
+
self.frames_groups_info = keyframes_group_info
|
43 |
+
self.frames_groups_info.extend(subframes_groups_info)
|
44 |
+
print(self.index.ntotal)
|
45 |
+
assert self.index.ntotal == len(
|
46 |
+
self.frames_groups_info
|
47 |
+
), "dtb_cursor::Index length and map lenght mismatch"
|
48 |
+
|
49 |
+
def kNN_search(self, query_vector: str, topk: int = 10):
|
50 |
+
results = []
|
51 |
+
distances, ids = self.index.search(query_vector, topk)
|
52 |
+
for i in range(len(ids[0])):
|
53 |
+
frame_detail = self.frames_groups_info[ids[0][i]]
|
54 |
+
frame_detail["distance"] = str(distances[0][i])
|
55 |
+
frame_detail["folder"] = (
|
56 |
+
"Keyframes" if ids[0][i] < self.no_keyframes else "Subframes"
|
57 |
+
)
|
58 |
+
results.append(frame_detail)
|
59 |
+
return results
|
src/itr/router.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, File, status
|
2 |
+
from fastapi.responses import JSONResponse
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
from .dtb_cursor import DatabaseCursor
|
6 |
+
from .vlm_model import VisionLanguageModel
|
7 |
+
|
8 |
+
|
9 |
+
class Item(BaseModel):
|
10 |
+
query_text: str
|
11 |
+
topk: int
|
12 |
+
|
13 |
+
|
14 |
+
router = APIRouter()
|
15 |
+
|
16 |
+
|
17 |
+
vectordb_cursor = None
|
18 |
+
vlm_model = None
|
19 |
+
|
20 |
+
|
21 |
+
def init_vectordb(**kargs):
|
22 |
+
# Singleton pattern
|
23 |
+
global vectordb_cursor
|
24 |
+
if vectordb_cursor is None:
|
25 |
+
vectordb_cursor = DatabaseCursor(**kargs)
|
26 |
+
|
27 |
+
|
28 |
+
def init_model(**kargs):
|
29 |
+
# Singleton
|
30 |
+
global vlm_model
|
31 |
+
if vlm_model is None:
|
32 |
+
vlm_model = VisionLanguageModel(**kargs)
|
33 |
+
|
34 |
+
|
35 |
+
@router.post("/retrieval")
|
36 |
+
async def retrieve(item: Item) -> JSONResponse:
|
37 |
+
try:
|
38 |
+
query_vector = vlm_model.get_embedding(input=item.query_text)
|
39 |
+
search_results = vectordb_cursor.kNN_search(query_vector, item.topk)
|
40 |
+
except Exception:
|
41 |
+
return JSONResponse(
|
42 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
43 |
+
content={"message": "Search error"},
|
44 |
+
)
|
45 |
+
|
46 |
+
return JSONResponse(
|
47 |
+
status_code=status.HTTP_200_OK,
|
48 |
+
content={"message": "success", "details": search_results},
|
49 |
+
)
|
src/itr/vlm_model.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import clip
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class VisionLanguageModel:
|
9 |
+
def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
|
10 |
+
self._load_model(model_name, device)
|
11 |
+
self.device = device
|
12 |
+
|
13 |
+
@lru_cache(maxsize=1)
|
14 |
+
def _load_model(self, model_name, device: str = "cpu"):
|
15 |
+
self.model, self.processor = clip.load(model_name, device=device)
|
16 |
+
|
17 |
+
def get_embedding(self, input: Union[str, Image.Image]):
|
18 |
+
if isinstance(input, str):
|
19 |
+
tokens = clip.tokenize(input).to(self.device)
|
20 |
+
vector = self.model.encode_text(tokens)
|
21 |
+
vector /= vector.norm(dim=-1, keepdim=True)
|
22 |
+
vector = vector.cpu().detach().numpy().astype("float32")
|
23 |
+
return vector
|
24 |
+
elif isinstance(input, Image.Image):
|
25 |
+
image_input = self.preprocess(input).unsqueeze(0).to(self.device)
|
26 |
+
vector = self.model.encode_image(image_input)
|
27 |
+
vector /= vector.norm(dim=-1, keepdim=True)
|
28 |
+
return vector
|
29 |
+
else:
|
30 |
+
raise Exception("Invalid input type")
|
src/main.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from config import settings
|
3 |
+
from fastapi import FastAPI, Request, status
|
4 |
+
from fastapi.exceptions import RequestValidationError
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
+
from fastapi.responses import JSONResponse, RedirectResponse
|
7 |
+
from itr.router import init_model, init_vectordb
|
8 |
+
from itr.router import router as router
|
9 |
+
|
10 |
+
app = FastAPI(title="Text-to-image Retrieval API")
|
11 |
+
|
12 |
+
|
13 |
+
app.add_middleware(
|
14 |
+
CORSMiddleware,
|
15 |
+
allow_origins=settings.CORS_ORIGINS,
|
16 |
+
allow_headers=settings.CORS_HEADERS,
|
17 |
+
allow_credentials=True,
|
18 |
+
allow_methods=["*"],
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
@app.exception_handler(RequestValidationError)
|
23 |
+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
24 |
+
# Get the original 'detail' list of errors
|
25 |
+
details = exc.errors()
|
26 |
+
error_details = []
|
27 |
+
|
28 |
+
for error in details:
|
29 |
+
error_details.append({"error": f"{error['msg']} {str(error['loc'])}"})
|
30 |
+
return JSONResponse(content={"message": error_details})
|
31 |
+
|
32 |
+
|
33 |
+
@app.on_event("startup")
|
34 |
+
async def startup_event():
|
35 |
+
init_vectordb(
|
36 |
+
index_file_path=settings.INDEX_FILE_PATH,
|
37 |
+
index_subframes_file_path=settings.INDEX_SUBFRAMES_FILE_PATH,
|
38 |
+
keyframes_groups_json_path=settings.KEYFRAMES_GROUPS_JSON_PATH,
|
39 |
+
subframes_groups_json_path=settings.SUBFRAMES_GROUPS_JSON_PATH,
|
40 |
+
)
|
41 |
+
device = (
|
42 |
+
"cuda" if settings.DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
|
43 |
+
)
|
44 |
+
init_model(model_name=settings.MODEL_NAME, device=device)
|
45 |
+
|
46 |
+
|
47 |
+
@app.get("/", include_in_schema=False)
|
48 |
+
async def root() -> None:
|
49 |
+
return RedirectResponse("/docs")
|
50 |
+
|
51 |
+
|
52 |
+
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
|
53 |
+
async def perform_healthcheck() -> None:
|
54 |
+
return JSONResponse(content={"message": "success"})
|
55 |
+
|
56 |
+
|
57 |
+
app.include_router(router)
|
58 |
+
|
59 |
+
|
60 |
+
# Start API
|
61 |
+
if __name__ == "__main__":
|
62 |
+
import uvicorn
|
63 |
+
|
64 |
+
uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)
|