Spaces: Sleeping
Futyn-Maker committed
Commit 7e1f5f6 • Parent(s): f2e0a2c
Deploy the app
- .gitattributes +2 -0
- README.md +1 -1
- config.yaml +23 -0
- data/raw/textmeme.json +3 -0
- indexes/bm25/bm25_index.pkl +3 -0
- indexes/semantic/embeddings.npy +3 -0
- meme_search.db +3 -0
- requirements.txt +169 -0
- scripts/build_bm25_index.py +79 -0
- scripts/build_semantic_index.py +76 -0
- scripts/data_collector.py +77 -0
- scripts/make_db.py +100 -0
- src/db/crud.py +210 -0
- src/db/models.py +28 -0
- src/indexing/bm25_indexer.py +40 -0
- src/indexing/semantic_indexer.py +52 -0
- src/interface.py +124 -0
- src/main.py +104 -0
- src/parsing/vk_meme_parser.py +139 -0
- src/preprocessing/__pycache__/mystem_tokenizer.cpython-311.pyc +0 -0
- src/preprocessing/mystem_tokenizer.py +46 -0
- src/search/bm25_search.py +89 -0
- src/search/semantic_search.py +85 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
 colorTo: pink
 sdk: gradio
 sdk_version: 5.1.0
-app_file: interface.py
+app_file: src/interface.py
 pinned: false
 license: wtfpl
 short_description: Search for Russian-language memes by their text descriptions
config.yaml
ADDED
@@ -0,0 +1,23 @@
+# Configuration file for the Meme Search Engine project
+
+vk_parser:
+  api_token: "YOUR_TOKEN_HERE"
+  meme_pages:
+    - "textmeme"
+    # - "badtextmeme"
+
+data_folders:
+  raw_data: "data/raw"
+  # images: "data/images"
+
+database:
+  url: "sqlite:///./meme_search.db"
+
+index_folders:
+  bm25: "indexes/bm25"
+  semantic: "indexes/semantic"
+
+semantic_search:
+  model: "intfloat/multilingual-e5-small"
+  query_prefix: "query: "
+  document_prefix: "passage: "
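For context, every script in this commit consumes this file through OmegaConf. A minimal sketch of that shared load pattern, assuming it is run from the repository root (the printed value simply echoes the YAML above):

    from omegaconf import OmegaConf

    # Load the YAML config and convert it to a plain dict, as the scripts here do
    config = OmegaConf.to_container(OmegaConf.load('config.yaml'))
    print(config['semantic_search']['model'])  # -> "intfloat/multilingual-e5-small"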
data/raw/textmeme.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96ef95553f897ae10cbe0fcca505091ff41aef726840f989dd2d944ffefcc5d0
+size 6394668
indexes/bm25/bm25_index.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ca6b40f2349502f78f852b68e1c6589d5777cc34ffafb3935c12f2ea6b931e5
+size 3973591
indexes/semantic/embeddings.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5509dd737152797eee7d5fc2e6ba90ce059ecb7c8c1632d5d7ff17bb38dddbd
+size 10537088
meme_search.db
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbe21217d65e17c68c806d6baed8ac9da166fd65d34e332f93033f357b8d65ad
+size 6897664
requirements.txt
ADDED
@@ -0,0 +1,169 @@
+aiofiles==23.2.1
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.6.2.post1
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==2.4.1
+async-lru==2.0.4
+attrs==24.2.0
+babel==2.16.0
+beautifulsoup4==4.12.3
+bleach==6.1.0
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.4.0
+click==8.1.7
+comm==0.2.2
+debugpy==1.8.7
+decorator==5.1.1
+defusedxml==0.7.1
+eval_type_backport==0.2.0
+executing==2.1.0
+fastapi==0.115.2
+fastjsonschema==2.20.0
+ffmpy==0.4.0
+filelock==3.16.1
+fire==0.7.0
+fqdn==1.5.1
+fsspec==2024.9.0
+gradio==5.1.0
+gradio_client==1.4.0
+greenlet==3.1.1
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.25.2
+idna==3.10
+ipykernel==6.29.5
+ipython==8.28.0
+ipywidgets==8.1.5
+isoduration==20.11.0
+jedi==0.19.1
+Jinja2==3.1.4
+joblib==1.4.2
+json5==0.9.25
+jsonpath-python==1.0.6
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter==1.1.1
+jupyter-console==6.6.3
+jupyter-events==0.10.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+jupyter_server==2.14.2
+jupyter_server_terminals==0.5.3
+jupyterlab==4.2.5
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
+jupyterlab_widgets==3.0.13
+loguru==0.7.2
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mistralai==1.1.0
+mistune==3.0.2
+mpmath==1.3.0
+mypy-extensions==1.0.0
+nbclient==0.10.0
+nbconvert==7.16.4
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.4.1
+nltk==3.9.1
+notebook==7.2.2
+notebook_shim==0.2.4
+numpy==2.1.2
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+orjson==3.10.7
+overrides==7.7.0
+packaging==24.1
+pandas==2.2.3
+pandocfilters==1.5.1
+parso==0.8.4
+pexpect==4.9.0
+pillow==10.4.0
+platformdirs==4.3.6
+prometheus_client==0.21.0
+prompt_toolkit==3.0.48
+psutil==6.1.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pycparser==2.22
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydub==0.25.1
+Pygments==2.18.0
+pymystem3==0.2.0
+python-dateutil==2.8.2
+python-json-logger==2.0.7
+python-multipart==0.0.12
+pytz==2024.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+rank-bm25==0.2.2
+referencing==0.35.1
+regex==2024.9.11
+requests==2.32.3
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==13.9.2
+rpds-py==0.20.0
+ruff==0.7.0
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+semantic-version==2.10.0
+Send2Trash==1.8.3
+sentence-transformers==3.2.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+soupsieve==2.6
+SQLAlchemy==2.0.36
+stack-data==0.6.3
+starlette==0.40.0
+sympy==1.13.1
+termcolor==2.5.0
+terminado==0.18.1
+threadpoolctl==3.5.0
+tinycss2==1.3.0
+tokenizers==0.20.1
+tomlkit==0.12.0
+torch==2.5.0
+tornado==6.4.1
+tqdm==4.66.5
+traitlets==5.14.3
+transformers==4.45.2
+triton==3.1.0
+typer==0.12.5
+types-python-dateutil==2.9.0.20241003
+typing-inspect==0.9.0
+typing_extensions==4.12.2
+tzdata==2024.2
+uri-template==1.3.0
+urllib3==2.2.3
+uvicorn==0.32.0
+vk-api==11.9.9
+wcwidth==0.2.13
+webcolors==24.8.0
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==12.0
+widgetsnbextension==4.0.13
scripts/build_bm25_index.py
ADDED
@@ -0,0 +1,79 @@
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any, List
+
+from loguru import logger
+from omegaconf import OmegaConf
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+
+def get_meme_corpus(db, crud) -> List[str]:
+    """
+    Retrieve all meme texts from the database.
+
+    Args:
+        db: Database session.
+        crud: CRUD operations module.
+
+    Returns:
+        List[str]: List of meme texts.
+    """
+    memes = crud.get_all_memes(db)
+    corpus = [meme.text for meme in memes]
+    logger.info(f"Retrieved {len(corpus)} memes from the database")
+    return corpus
+
+
+def build_bm25_index(corpus: List[str],
+                     config: Dict[str, Any],
+                     mystem_tokenizer,
+                     BM25Indexer):
+    """
+    Build and save the BM25 index.
+
+    Args:
+        corpus (List[str]): List of meme texts.
+        config (Dict[str, Any]): Configuration dictionary.
+        mystem_tokenizer: MystemTokenizer instance.
+        BM25Indexer: BM25Indexer class.
+    """
+    indexer = BM25Indexer(corpus, mystem_tokenizer.tokenize)
+    bm25_index_folder = config['index_folders']['bm25']
+    os.makedirs(bm25_index_folder, exist_ok=True)
+    indexer.create_index(bm25_index_folder)
+    logger.info(f"BM25S index created and saved in {bm25_index_folder}")
+
+
+def main():
+    from src.db import crud
+    from src.preprocessing.mystem_tokenizer import MystemTokenizer
+    from src.indexing.bm25_indexer import BM25Indexer
+
+    logger.add("logs/build_bm25s_index.log", rotation="10 MB")
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    engine = create_engine(config['database']['url'])
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    try:
+        corpus = get_meme_corpus(db, crud)
+        mystem_tokenizer = MystemTokenizer()
+        build_bm25_index(corpus, config, mystem_tokenizer, BM25Indexer)
+    finally:
+        db.close()
+
+    logger.info("BM25S index building completed")
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    main()
scripts/build_semantic_index.py
ADDED
@@ -0,0 +1,76 @@
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any, List
+
+from loguru import logger
+from omegaconf import OmegaConf
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+
+def get_meme_corpus(db, crud) -> List[str]:
+    """
+    Retrieve all meme texts from the database.
+
+    Args:
+        db: Database session.
+        crud: CRUD operations module.
+
+    Returns:
+        List[str]: List of meme texts.
+    """
+    memes = crud.get_all_memes(db)
+    corpus = [meme.text for meme in memes]
+    logger.info(f"Retrieved {len(corpus)} memes from the database")
+    return corpus
+
+
+def build_semantic_index(
+        corpus: List[str], config: Dict[str, Any], SemanticIndexer):
+    """
+    Build and save the semantic index.
+
+    Args:
+        corpus (List[str]): List of meme texts.
+        config (Dict[str, Any]): Configuration dictionary.
+        SemanticIndexer: SemanticIndexer class.
+    """
+    model = config['semantic_search']['model']
+    prefix = config['semantic_search']['document_prefix']
+    indexer = SemanticIndexer(corpus, model=model, prefix=prefix)
+
+    semantic_index_folder = config['index_folders']['semantic']
+    os.makedirs(semantic_index_folder, exist_ok=True)
+    indexer.create_index(semantic_index_folder)
+    logger.info(f"Semantic index created and saved in {semantic_index_folder}")
+
+
+def main():
+    from src.db import crud
+    from src.indexing.semantic_indexer import SemanticIndexer
+
+    logger.add("logs/build_semantic_index.log", rotation="10 MB")
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    engine = create_engine(config['database']['url'])
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    try:
+        corpus = get_meme_corpus(db, crud)
+        build_semantic_index(corpus, config, SemanticIndexer)
+    finally:
+        db.close()
+
+    logger.info("Semantic index building completed")
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    main()
scripts/data_collector.py
ADDED
@@ -0,0 +1,77 @@
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any
+
+from loguru import logger
+from omegaconf import OmegaConf
+
+
+def process_public(parser, public_id: str, config: Dict[str, Any]) -> None:
+    """
+    Process a single public page, updating or creating its JSON file.
+
+    Args:
+        parser: VK meme parser instance.
+        public_id (str): ID or short name of the public page.
+        config (Dict[str, Any]): Configuration dictionary.
+    """
+    raw_data_path = config['data_folders']['raw_data']
+    json_file_path = os.path.join(raw_data_path, f"{public_id}.json")
+
+    logger.info(f"Processing public: {public_id}")
+
+    memes_data = parser.get_memes(public_id)
+
+    if os.path.exists(json_file_path):
+        # Update existing JSON file
+        with open(json_file_path, 'r', encoding='utf-8') as file:
+            existing_data = json.load(file)
+
+        existing_posts = {post['id']: post for post in existing_data['posts']}
+        new_posts = [post for post in memes_data['posts']
+                     if post['id'] not in existing_posts]
+
+        # Add new posts to the beginning of the list
+        existing_data['posts'] = new_posts + existing_data['posts']
+
+        with open(json_file_path, 'w', encoding='utf-8') as file:
+            json.dump(existing_data, file, ensure_ascii=False, indent=2)
+
+        logger.info(f"Updated {len(new_posts)} new posts for {public_id}")
+
+    else:
+        # Create new JSON file
+        with open(json_file_path, 'w', encoding='utf-8') as file:
+            json.dump(memes_data, file, ensure_ascii=False, indent=2)
+
+        logger.info(
+            f"Created new JSON file for {public_id} with {len(memes_data['posts'])} posts")
+
+
+def main():
+    from src.parsing.vk_meme_parser import VKMemeParser
+
+    logger.add("logs/data_collector.log", rotation="10 MB")
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    parser = VKMemeParser(config['vk_parser']['api_token'])
+
+    for folder in config['data_folders'].values():
+        os.makedirs(folder, exist_ok=True)
+
+    for public_id in config['vk_parser']['meme_pages']:
+        process_public(parser, public_id, config)
+
+    logger.info("Data collection process completed")
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    main()
scripts/make_db.py
ADDED
@@ -0,0 +1,100 @@
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any, List
+
+from loguru import logger
+from omegaconf import OmegaConf
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+
+def process_json_files(
+        raw_data_path: str) -> tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
+    """
+    Process all JSON files in the raw data folder.
+
+    Args:
+        raw_data_path (str): Path to the folder containing JSON files.
+
+    Returns:
+        tuple: Lists of public and meme data to be added to the database.
+    """
+    publics_to_add: List[Dict[str, str]] = []
+    memes_to_add: List[Dict[str, Any]] = []
+
+    for filename in os.listdir(raw_data_path):
+        if filename.endswith('.json'):
+            public_vk = filename[:-5]  # Remove .json extension
+            file_path = os.path.join(raw_data_path, filename)
+
+            with open(file_path, 'r', encoding='utf-8') as file:
+                data = json.load(file)
+
+            publics_to_add.append({
+                "public_vk": public_vk,
+                "public_name": data['name']
+            })
+
+            for post in data['posts']:
+                memes_to_add.append({
+                    "public_vk": public_vk,
+                    "text": post['text'],
+                    "image_url": post['image_url']
+                })
+
+            logger.info(
+                f"Processed file: {filename}, found {len(data['posts'])} memes")
+
+    return publics_to_add, memes_to_add
+
+
+def main():
+    from src.db.models import Base
+    from src.db import crud
+
+    logger.add("logs/make_db.log", rotation="10 MB")
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    engine = create_engine(config['database']['url'])
+
+    # Drop all existing tables and create new ones
+    Base.metadata.drop_all(bind=engine)
+    Base.metadata.create_all(bind=engine)
+
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    raw_data_path = config['data_folders']['raw_data']
+
+    publics_to_add, memes_to_add = process_json_files(raw_data_path)
+
+    # Add all publics to the database
+    added_publics = crud.add_publics(db, publics_to_add)
+
+    # Create a mapping of public_vk to public_id
+    public_vk_to_id = {public.public_vk: public.id for public in added_publics}
+
+    # Update memes with correct public_id
+    for meme in memes_to_add:
+        meme['public_id'] = public_vk_to_id[meme.pop('public_vk')]
+
+    # Add all memes to the database
+    crud.add_memes(db, memes_to_add)
+
+    logger.info(
+        f"Added {len(added_publics)} publics and {len(memes_to_add)} memes to the database")
+
+    db.close()
+    logger.info("Database population completed")
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    main()
src/db/crud.py
ADDED
@@ -0,0 +1,210 @@
+from typing import List, Dict, Any
+
+from sqlalchemy.orm import Session
+
+from . import models
+
+
+def add_publics(
+        db: Session, publics: List[Dict[str, str]]) -> List[models.Public]:
+    """
+    Add multiple public pages to the database.
+
+    Args:
+        db (Session): The database session.
+        publics (List[Dict[str, str]]): List of public page data.
+
+    Returns:
+        List[models.Public]: List of added public page objects.
+    """
+    db_publics = [models.Public(**public) for public in publics]
+    db.add_all(db_publics)
+    db.commit()
+    for public in db_publics:
+        db.refresh(public)
+    return db_publics
+
+
+def add_memes(db: Session, memes: List[Dict[str, Any]]) -> List[models.Meme]:
+    """
+    Add multiple memes to the database.
+
+    Args:
+        db (Session): The database session.
+        memes (List[Dict[str, Any]]): List of meme data.
+
+    Returns:
+        List[models.Meme]: List of added meme objects.
+    """
+    db_memes = [models.Meme(**meme) for meme in memes]
+    db.add_all(db_memes)
+    db.commit()
+    for meme in db_memes:
+        db.refresh(meme)
+    return db_memes
+
+
+def get_memes_by_publics(db: Session,
+                         public_ids: List[int]) -> List[models.Meme]:
+    """
+    Retrieve memes associated with specific public pages.
+
+    Args:
+        db (Session): The database session.
+        public_ids (List[int]): List of public page IDs.
+
+    Returns:
+        List[models.Meme]: List of meme objects.
+    """
+    return db.query(models.Meme).filter(
+        models.Meme.public_id.in_(public_ids)).all()
+
+
+def get_all_memes(db: Session) -> List[models.Meme]:
+    """
+    Retrieve all memes from the database.
+
+    Args:
+        db (Session): The database session.
+
+    Returns:
+        List[models.Meme]: List of all meme objects.
+    """
+    return db.query(models.Meme).all()
+
+
+def get_all_publics(db: Session) -> List[models.Public]:
+    """
+    Retrieve all public pages from the database.
+
+    Args:
+        db (Session): The database session.
+
+    Returns:
+        List[models.Public]: List of all public page objects.
+    """
+    return db.query(models.Public).all()
+
+
+def get_memes_by_ids(db: Session, meme_ids: List[int]) -> List[models.Meme]:
+    """
+    Retrieve memes by their IDs.
+
+    Args:
+        db (Session): The database session.
+        meme_ids (List[int]): List of meme IDs.
+
+    Returns:
+        List[models.Meme]: List of meme objects.
+    """
+    return db.query(models.Meme).filter(models.Meme.id.in_(meme_ids)).all()
+
+
+def get_publics_by_ids(db: Session,
+                       public_ids: List[int]) -> List[models.Public]:
+    """
+    Retrieve public pages by their IDs.
+
+    Args:
+        db (Session): The database session.
+        public_ids (List[int]): List of public page IDs.
+
+    Returns:
+        List[models.Public]: List of public page objects.
+    """
+    return db.query(models.Public).filter(
+        models.Public.id.in_(public_ids)).all()
+
+
+def delete_memes(db: Session, meme_ids: List[int]) -> int:
+    """
+    Delete memes by their IDs.
+
+    Args:
+        db (Session): The database session.
+        meme_ids (List[int]): List of meme IDs to delete.
+
+    Returns:
+        int: Number of deleted memes.
+    """
+    deleted_count = db.query(models.Meme).filter(
+        models.Meme.id.in_(meme_ids)).delete(synchronize_session='fetch')
+    db.commit()
+    return deleted_count
+
+
+def delete_publics(db: Session, public_ids: List[int]) -> int:
+    """
+    Delete public pages and their associated memes.
+
+    Args:
+        db (Session): The database session.
+        public_ids (List[int]): List of public page IDs to delete.
+
+    Returns:
+        int: Number of deleted public pages.
+    """
+    # First, delete associated memes
+    db.query(models.Meme).filter(
+        models.Meme.public_id.in_(public_ids)).delete(
+        synchronize_session='fetch')
+
+    # Then delete the publics
+    deleted_count = db.query(models.Public).filter(
+        models.Public.id.in_(public_ids)).delete(synchronize_session='fetch')
+    db.commit()
+    return deleted_count
+
+
+def get_memes_with_public_info(
+        db: Session, meme_ids: List[int]) -> List[tuple[models.Meme, models.Public]]:
+    """
+    Retrieve memes with their associated public page information.
+
+    Args:
+        db (Session): The database session.
+        meme_ids (List[int]): List of meme IDs.
+
+    Returns:
+        List[tuple[models.Meme, models.Public]]: List of tuples containing meme and public page objects.
+    """
+    return db.query(models.Meme, models.Public).\
+        join(models.Public, models.Meme.public_id == models.Public.id).\
+        filter(models.Meme.id.in_(meme_ids)).all()
+
+
+def update_memes(db: Session, meme_updates: List[Dict[str, Any]]) -> None:
+    """
+    Update multiple memes in the database.
+
+    Args:
+        db (Session): The database session.
+        meme_updates (List[Dict[str, Any]]): List of meme update data.
+    """
+    for update in meme_updates:
+        meme_id = update.pop('id')
+        db.query(models.Meme).filter(models.Meme.id == meme_id).update(
+            {getattr(models.Meme, k): v for k, v in update.items()},
+            synchronize_session='fetch')
+    db.commit()
+
+
+def update_publics(db: Session, public_updates: List[Dict[str, Any]]) -> None:
+    """
+    Update multiple public pages in the database.
+
+    Args:
+        db (Session): The database session.
+        public_updates (List[Dict[str, Any]]): List of public page update data.
+    """
+    for update in public_updates:
+        public_id = update.pop('id')
+        db.query(models.Public).filter(models.Public.id == public_id).update(
+            {getattr(models.Public, k): v for k, v in update.items()},
+            synchronize_session='fetch')
+    db.commit()
src/db/models.py
ADDED
@@ -0,0 +1,28 @@
+from sqlalchemy import Column, Integer, String, ForeignKey
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+
+Base = declarative_base()
+
+
+class Public(Base):
+    __tablename__ = "publics"
+
+    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
+    public_vk = Column(String, unique=True, index=True)
+    public_name = Column(String)
+
+    memes = relationship("Meme", back_populates="public")
+
+
+class Meme(Base):
+    __tablename__ = "memes"
+
+    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
+    public_id = Column(Integer, ForeignKey("publics.id"))
+    text = Column(String)
+    image_url = Column(String)
+    local_image_path = Column(String)
+
+    public = relationship("Public", back_populates="memes")
src/indexing/bm25_indexer.py
ADDED
@@ -0,0 +1,40 @@
+import os
+import pickle
+from typing import List, Callable
+
+from rank_bm25 import BM25Okapi
+
+
+class BM25Indexer:
+    def __init__(self, corpus: List[str],
+                 tokenizer: Callable[[str], List[str]]):
+        """
+        Initialize the BM25Indexer.
+
+        Args:
+            corpus (List[str]): The corpus to be indexed.
+            tokenizer (Callable[[str], List[str]]): A function to tokenize the text.
+        """
+        self.corpus = corpus
+        self.tokenizer = tokenizer
+        self.bm25 = None
+
+    def create_index(self, save_dir: str) -> None:
+        """
+        Create and save the BM25 index.
+
+        Args:
+            save_dir (str): Directory to save the index.
+        """
+        # Ensure the save directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        # Tokenize the corpus
+        tokenized_corpus = [self.tokenizer(doc) for doc in self.corpus]
+
+        # Create the BM25 model
+        self.bm25 = BM25Okapi(tokenized_corpus)
+
+        # Save the BM25 index
+        with open(os.path.join(save_dir, 'bm25_index.pkl'), 'wb') as f:
+            pickle.dump(self.bm25, f)
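As a usage note, a minimal sketch of driving BM25Indexer directly, assuming it runs from the repository root. The two-document corpus and the naive whitespace tokenizer are illustrative assumptions; the real pipeline in scripts/build_bm25_index.py passes MystemTokenizer.tokenize and the corpus from the database:

    from src.indexing.bm25_indexer import BM25Indexer

    # Toy corpus and tokenizer, for illustration only
    corpus = ["кот сидит на окне", "собака бежит по парку"]
    indexer = BM25Indexer(corpus, tokenizer=str.split)
    indexer.create_index("indexes/bm25")  # writes indexes/bm25/bm25_index.pkl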
src/indexing/semantic_indexer.py
ADDED
@@ -0,0 +1,52 @@
+import os
+from typing import List, Optional
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+
+class SemanticIndexer:
+    def __init__(
+            self,
+            corpus: List[str],
+            model: str,
+            prefix: Optional[str] = None
+    ):
+        """
+        Initialize the SemanticIndexer.
+
+        Args:
+            corpus (List[str]): The corpus to be indexed.
+            model (str): The name or path of the SentenceTransformer model to use.
+            prefix (Optional[str], optional): A prefix to add to each text in the corpus. Defaults to None.
+        """
+        self.corpus = corpus
+        self.model = SentenceTransformer(model)
+        self.prefix = prefix
+        self.embeddings = None
+
+    def create_index(self, save_dir: str) -> None:
+        """
+        Create and save the semantic index.
+
+        Args:
+            save_dir (str): Directory to save the embeddings.
+        """
+        # Ensure the save directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        # Prepare texts with prefix if provided
+        texts = [
+            f"{self.prefix}{text}" if self.prefix else text
+            for text in self.corpus]
+
+        # Create embeddings
+        self.embeddings = self.model.encode(
+            texts,
+            show_progress_bar=True,
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )
+
+        # Save embeddings
+        embeddings_file = os.path.join(save_dir, "embeddings.npy")
+        np.save(embeddings_file, self.embeddings)
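One design point worth spelling out: because create_index normalizes the embeddings, cosine similarity at query time reduces to a plain dot product, which is exactly what src/search/semantic_search.py computes (documents get the "passage: " prefix from config.yaml, queries the "query: " prefix). A minimal sketch of that relationship, using a stored row as a stand-in query vector rather than real model output:

    import numpy as np

    # With L2-normalized rows, cosine similarity is just a dot product
    doc_embeddings = np.load("indexes/semantic/embeddings.npy")  # shape (n_docs, dim)
    query_embedding = doc_embeddings[0]           # stand-in for an encoded "query: ..." text
    scores = doc_embeddings @ query_embedding     # cosine scores in [-1, 1]
    print(scores.argsort()[-3:][::-1])            # indices of the top-3 documents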
src/interface.py
ADDED
@@ -0,0 +1,124 @@
+import sys
+from pathlib import Path
+
+import gradio as gr
+from omegaconf import OmegaConf
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+
+def initialize_search_engines(db, config):
+    """
+    Initialize both BM25 and Semantic search engines.
+
+    Args:
+        db: Database session.
+        config: Configuration dictionary.
+
+    Returns:
+        tuple: Initialized BM25Search and SemanticSearch engines.
+    """
+    from search.bm25_search import BM25Search
+    from search.semantic_search import SemanticSearch
+    from preprocessing.mystem_tokenizer import MystemTokenizer
+
+    custom_tokenizer = MystemTokenizer()
+    bm25_search = BM25Search(
+        db,
+        config['index_folders']['bm25'],
+        custom_tokenizer.tokenize
+    )
+    semantic_search = SemanticSearch(
+        db,
+        model=config['semantic_search']['model'],
+        embeddings_file=f"{config['index_folders']['semantic']}/embeddings.npy",
+        prefix=config['semantic_search']['query_prefix'])
+
+    return bm25_search, semantic_search
+
+
+def search_memes(query: str, search_type: str, num_results: int):
+    """
+    Search for memes using the specified search method.
+
+    Args:
+        query (str): The search query.
+        search_type (str): The type of search to perform. Either 'BM25' or 'Семантический'.
+        num_results (int): The number of results to return.
+
+    Returns:
+        tuple: A tuple containing the search results and search time.
+    """
+    if search_type == "BM25":
+        results = bm25_search.search(query, num_results)
+    else:
+        results = semantic_search.search(query, num_results)
+
+    output = []
+    for result in results['results']:
+        output.append((result['image_url'], result['text']))
+
+    return output, f"Время поиска: {results['search_time']:.4f} секунд"
+
+
+def main():
+    global bm25_search, semantic_search
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    # Initialize database session
+    engine = create_engine(config['database']['url'])
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    # Initialize search engines
+    bm25_search, semantic_search = initialize_search_engines(db, config)
+
+    # Gradio interface
+    with gr.Blocks() as demo:
+        gr.Markdown("# Поиск мемов")
+        gr.Markdown(
+            "Добро пожаловать в приложение для поиска мемов! Введите запрос, выберите тип поиска и количество результатов."
+        )
+
+        with gr.Row():
+            query = gr.Textbox(label="Запрос")
+            search_type = gr.Radio(
+                ["BM25", "Семантический"],
+                label="Тип поиска",
+                value="BM25"
+            )
+            num_results = gr.Slider(
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=1,
+                label="Количество результатов"
+            )
+
+        search_button = gr.Button("Найти")
+
+        output_gallery = gr.Gallery(
+            label="Результаты",
+            show_label=False,
+            columns=3,
+            height=400
+        )
+        output_time = gr.Markdown()
+
+        search_button.click(
+            fn=search_memes,
+            inputs=[query, search_type, num_results],
+            outputs=[output_gallery, output_time]
+        )
+
+    demo.launch()
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    main()
src/main.py
ADDED
@@ -0,0 +1,104 @@
+import sys
+from pathlib import Path
+
+import fire
+from omegaconf import OmegaConf
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+
+def initialize_bm25_search(db, config):
+    """
+    Initialize BM25 search engine.
+
+    Args:
+        db: Database session.
+        config: Configuration dictionary.
+
+    Returns:
+        BM25Search: Initialized BM25 search engine.
+    """
+    from search.bm25_search import BM25Search
+    from preprocessing.mystem_tokenizer import MystemTokenizer
+
+    custom_tokenizer = MystemTokenizer()
+    return BM25Search(
+        db,
+        config['index_folders']['bm25'],
+        custom_tokenizer.tokenize
+    )
+
+
+def initialize_semantic_search(db, config):
+    """
+    Initialize semantic search engine.
+
+    Args:
+        db: Database session.
+        config: Configuration dictionary.
+
+    Returns:
+        SemanticSearch: Initialized semantic search engine.
+    """
+    from search.semantic_search import SemanticSearch
+    return SemanticSearch(
+        db,
+        model=config['semantic_search']['model'],
+        embeddings_file=f"{config['index_folders']['semantic']}/embeddings.npy",
+        prefix=config['semantic_search']['query_prefix'])
+
+
+def search_memes(query: str, search_type: str = 'bm25', num: int = 1):
+    """
+    Search for memes using the specified search method.
+
+    Args:
+        query (str): The search query.
+        search_type (str): The type of search to perform. Either 'bm25' or 'semantic'. Defaults to 'bm25'.
+        num (int): The number of results to return. Defaults to 1.
+
+    Returns:
+        None: Prints the results to the console.
+    """
+    if not query:
+        print("Error: Query is required.")
+        return
+    if search_type not in ['bm25', 'semantic']:
+        print("Error: Invalid search type. Use 'bm25' or 'semantic'.")
+        return
+    if num < 1:
+        print("Error: Number of results must be at least 1.")
+        return
+
+    # Load configuration
+    config = OmegaConf.load('config.yaml')
+    config = OmegaConf.to_container(config)
+
+    # Initialize database session
+    engine = create_engine(config['database']['url'])
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    try:
+        # Initialize search engine
+        if search_type == 'bm25':
+            search_engine = initialize_bm25_search(db, config)
+        elif search_type == 'semantic':
+            search_engine = initialize_semantic_search(db, config)
+
+        # Perform search
+        results = search_engine.search(query, num)
+
+        # Print results
+        for result in results['results']:
+            print(result['text'])
+        print(f"\nSearch time: {results['search_time']:.4f} seconds")
+    finally:
+        db.close()
+
+
+if __name__ == "__main__":
+    # Set up project root path
+    project_root = Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_root))
+    fire.Fire(search_memes)
src/parsing/vk_meme_parser.py
ADDED
@@ -0,0 +1,139 @@
+import os
+from typing import Optional, Dict, Any
+from urllib.parse import urlparse
+
+import requests
+import vk_api
+
+
+class VKMemeParser:
+    def __init__(self, token: str):
+        """
+        Initialize the VK Meme Parser.
+
+        Args:
+            token (str): VK API access token.
+        """
+        self.vk_session = vk_api.VkApi(token=token)
+        self.vk = self.vk_session.get_api()
+
+    def _process_post(self, post: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        """
+        Process a single post and extract relevant information.
+
+        Args:
+            post (Dict[str, Any]): A dictionary containing post data.
+
+        Returns:
+            Optional[Dict[str, Any]]: A dictionary with post ID, text, and image URL if valid,
+                None otherwise.
+        """
+        # Check if the post is valid
+        if (post.get("marked_as_ads") or
+                "is_pinned" in post or
+                "copy_history" in post or
+                len(post.get("attachments", [])) != 1 or
+                post["attachments"][0]["type"] != "photo"):
+            return None
+
+        post_id = post["id"]
+        text = post["text"].strip()
+
+        # Get the largest available photo
+        photo_sizes = post["attachments"][0]["photo"]["sizes"]
+        largest_photo = max(
+            photo_sizes,
+            key=lambda x: x["width"] * x["height"])
+        image_url = largest_photo["url"]
+
+        return {
+            "id": post_id,
+            "text": text,
+            "image_url": image_url
+        }
+
+    def get_memes(self, public_id: str) -> Dict[str, Any]:
+        """
+        Retrieve and process all meme posts from a specified public page.
+
+        Args:
+            public_id (str): ID or short name of the public page.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the public's name and processed meme posts.
+        """
+        memes = []
+
+        # Determine whether to use domain or owner_id
+        if public_id.isdigit() or (public_id.startswith("-")
+                                   and public_id[1:].isdigit()):
+            params: Dict[str, Any] = {"owner_id": int(public_id)}
+        else:
+            params: Dict[str, Any] = {"domain": public_id}
+
+        # Fetch public's name
+        group_info = self.vk.groups.getById(group_id=public_id)[0]
+        group_name = group_info['name']
+
+        # Process posts
+        offset = 0
+        while True:
+            # Fetch 100 posts at a time
+            params["count"] = 100
+            params["offset"] = offset
+            response = self.vk.wall.get(**params)
+
+            posts = response["items"]
+
+            for post in posts:
+                processed_post = self._process_post(post)
+                if processed_post:
+                    memes.append(processed_post)
+
+            # Check if we've reached the end of posts
+            if len(posts) < 100:
+                break
+
+            # Advance to the next page; wall.get is paginated by offset and
+            # does not return a "next_from" cursor
+            offset += 100
+
+        return {
+            "name": group_name,
+            "posts": memes
+        }
+
+    def download_image(
+            self,
+            image_url: str,
+            folder_path: str) -> Optional[str]:
+        """
+        Download an image from the given URL and save it to the specified folder.
+
+        Args:
+            image_url (str): The URL of the image to download.
+            folder_path (str): The path to the folder where the image should be saved.
+
+        Returns:
+            Optional[str]: The filename of the saved image, or None if the download failed.
+        """
+        try:
+            # Create the folder if it doesn't exist
+            os.makedirs(folder_path, exist_ok=True)
+
+            filename = os.path.basename(urlparse(image_url).path)
+            if not os.path.splitext(filename)[1]:
+                return None
+
+            image_path = os.path.join(folder_path, filename)
+
+            response = requests.get(image_url, stream=True)
+            response.raise_for_status()  # Raise an exception for bad status codes
+
+            with open(image_path, 'wb') as file:
+                for chunk in response.iter_content(chunk_size=8192):
+                    file.write(chunk)
+
+            return filename
+
+        except Exception as e:
+            print(f"Error downloading image from {image_url}: {str(e)}")
+            return None
src/preprocessing/__pycache__/mystem_tokenizer.cpython-311.pyc
ADDED
Binary file (2.47 kB).
src/preprocessing/mystem_tokenizer.py
ADDED
@@ -0,0 +1,46 @@
+from typing import List, Union
+
+import nltk
+from pymystem3 import Mystem
+
+
+class MystemTokenizer:
+    def __init__(self, stopwords: Union[List[str], str] = "ru"):
+        """
+        Initialize the MystemTokenizer.
+
+        Args:
+            stopwords (Union[List[str], str]): Either a list of stopwords or "ru" for Russian stopwords.
+        """
+        if stopwords == "ru":
+            try:
+                self.stopwords = nltk.corpus.stopwords.words("russian")
+            except LookupError:
+                # Download stopwords if not available
+                nltk.download("stopwords")
+                self.stopwords = nltk.corpus.stopwords.words("russian")
+        else:
+            self.stopwords = stopwords
+
+        self.mystem = Mystem()
+
+    def tokenize(self, text: str) -> List[str]:
+        """
+        Tokenize and lemmatize the input text, removing stopwords.
+
+        Args:
+            text (str): The input text to tokenize.
+
+        Returns:
+            List[str]: A list of lemmatized tokens.
+        """
+        # Lemmatize and tokenize using Mystem
+        lemmas = self.mystem.lemmatize(text.lower())
+
+        # Filter out non-letter tokens and stopwords
+        tokens = [
+            lemma for lemma in lemmas
+            if lemma.isalpha() and lemma not in self.stopwords
+        ]
+
+        return tokens
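For a concrete sense of what the BM25 pipeline indexes, a sketch of the tokenizer on a sample sentence, assuming the Mystem binary and NLTK stopwords are available; the exact lemmas depend on Mystem, so the output shown is an expectation, not a guarantee:

    from src.preprocessing.mystem_tokenizer import MystemTokenizer

    tokenizer = MystemTokenizer()
    print(tokenizer.tokenize("Коты сидели на окнах!"))
    # Expected: lowercase lemmas with stopwords and punctuation removed,
    # e.g. ['кот', 'сидеть', 'окно']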
src/search/bm25_search.py
ADDED
@@ -0,0 +1,89 @@
+import os
+import pickle
+import time
+from typing import List, Dict, Any, Callable
+
+import numpy as np
+from sqlalchemy.orm import Session
+from rank_bm25 import BM25Okapi
+
+from src.db import crud
+
+
+class BM25Search:
+    def __init__(
+            self,
+            db: Session,
+            index_folder: str,
+            tokenizer: Callable[[str], List[str]]
+    ):
+        """
+        Initialize the BM25Search.
+
+        Args:
+            db (Session): The database session.
+            index_folder (str): The folder containing the BM25 index.
+            tokenizer (Callable[[str], List[str]]): A function to tokenize the text.
+        """
+        self.db = db
+        self.tokenizer = tokenizer
+        self.bm25 = self._load_index(index_folder)
+
+    def _load_index(self, index_folder: str) -> BM25Okapi:
+        """
+        Load the BM25 index from a file.
+
+        Args:
+            index_folder (str): The folder containing the BM25 index.
+
+        Returns:
+            BM25Okapi: The loaded BM25 index.
+        """
+        with open(os.path.join(index_folder, 'bm25_index.pkl'), 'rb') as f:
+            return pickle.load(f)
+
+    def search(self, query: str, n: int = 3) -> Dict[str, Any]:
+        """
+        Perform a search using BM25.
+
+        Args:
+            query (str): The search query.
+            n (int, optional): The number of results to return. Defaults to 3.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing search results and search time.
+        """
+        start_time = time.time()
+
+        # Tokenize the query
+        query_tokens = self.tokenizer(query)
+
+        # Retrieve scores for all documents
+        scores = self.bm25.get_scores(query_tokens)
+
+        # Get top n document indices
+        top_n_indices = np.argsort(scores)[-n:][::-1]
+        top_n_scores = scores[top_n_indices]
+
+        # Adjust indices to match database IDs (assuming IDs start from 1)
+        db_ids = top_n_indices + 1
+
+        # Retrieve memes from the database
+        memes = crud.get_memes_by_ids(self.db, db_ids.tolist())
+
+        # Format the results
+        results = [
+            {
+                "id": meme.id,
+                "public_id": meme.public_id,
+                "text": meme.text,
+                "image_url": meme.image_url,
+                "score": top_n_scores[db_ids.tolist().index(meme.id)]
+            }
+            for meme in memes
+        ]
+
+        return {
+            "results": results,
+            "search_time": time.time() - start_time
+        }
src/search/semantic_search.py
ADDED
@@ -0,0 +1,85 @@
+import time
+from typing import List, Dict, Any, Optional
+
+import numpy as np
+from sqlalchemy.orm import Session
+from sentence_transformers import SentenceTransformer
+
+from src.db import crud
+
+
+class SemanticSearch:
+    def __init__(
+            self,
+            db: Session,
+            model: str,
+            embeddings_file: str,
+            prefix: Optional[str] = None
+    ):
+        """
+        Initialize the SemanticSearch.
+
+        Args:
+            db (Session): The database session.
+            model (str): The name or path of the SentenceTransformer model to use.
+            embeddings_file (str): Path to the file containing pre-computed embeddings.
+            prefix (Optional[str], optional): A prefix to add to each query. Defaults to None.
+        """
+        self.db = db
+        self.model = SentenceTransformer(model)
+        self.prefix = prefix
+        self.embeddings = np.load(embeddings_file)
+
+    def search(self, query: str, n: int = 3) -> Dict[str, Any]:
+        """
+        Perform a semantic search.
+
+        Args:
+            query (str): The search query.
+            n (int, optional): The number of results to return. Defaults to 3.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing search results and search time.
+        """
+        start_time = time.time()
+
+        # Prepare query with prefix if provided
+        query_text = f"{self.prefix}{query}" if self.prefix else query
+
+        # Encode the query
+        query_embedding = self.model.encode(
+            [query_text],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
+
+        # Compute similarity scores
+        scores = np.dot(self.embeddings, query_embedding)
+
+        # Get top n results
+        top_n_indices = np.argsort(scores)[-n:][::-1]
+        top_n_scores = scores[top_n_indices]
+
+        # Adjust indices to match database IDs
+        db_ids = top_n_indices + 1
+
+        # Retrieve memes from the database
+        memes = crud.get_memes_by_ids(self.db, db_ids.tolist())
+
+        # Format the results
+        results = [
+            {
+                "id": meme.id,
+                "public_id": meme.public_id,
+                "text": meme.text,
+                "image_url": meme.image_url,
+                "local_image_path": meme.local_image_path,
+                "score": top_n_scores[db_ids.tolist().index(meme.id)]
+            }
+            for meme in memes
+        ]
+
+        return {
+            "results": results,
+            "search_time": time.time() - start_time
+        }