Futyn-Maker committed on
Commit 7e1f5f6
1 Parent(s): f2e0a2c

Deploy the app

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
37
+ *.db filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.1.0
8
- app_file: interface.py
8
+ app_file: src/interface.py
9
  pinned: false
10
  license: wtfpl
11
  short_description: Search for Russian-language memes by their text descriptions
config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ # Configuration file for the Meme Search Engine project
2
+
3
+ vk_parser:
4
+ api_token: "YOUR_TOKEN_HERE"
5
+ meme_pages:
6
+ - "textmeme"
7
+ # - "badtextmeme"
8
+
9
+ data_folders:
10
+ raw_data: "data/raw"
11
+ # images: "data/images"
12
+
13
+ database:
14
+ url: "sqlite:///./meme_search.db"
15
+
16
+ index_folders:
17
+ bm25: "indexes/bm25"
18
+ semantic: "indexes/semantic"
19
+
20
+ semantic_search:
21
+ model: "intfloat/multilingual-e5-small"
22
+ query_prefix: "query: "
23
+ document_prefix: "passage: "
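
All of the scripts added in this commit consume config.yaml the same way: load it with OmegaConf and convert it to a plain dict. A minimal sketch, run from the repo root:

    from omegaconf import OmegaConf

    # Load config.yaml and turn it into a plain dict, as the scripts in this commit do.
    config = OmegaConf.to_container(OmegaConf.load("config.yaml"))
    print(config["semantic_search"]["model"])   # intfloat/multilingual-e5-small
    print(config["index_folders"]["bm25"])      # indexes/bm25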
data/raw/textmeme.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96ef95553f897ae10cbe0fcca505091ff41aef726840f989dd2d944ffefcc5d0
3
+ size 6394668
indexes/bm25/bm25_index.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ca6b40f2349502f78f852b68e1c6589d5777cc34ffafb3935c12f2ea6b931e5
3
+ size 3973591
indexes/semantic/embeddings.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5509dd737152797eee7d5fc2e6ba90ce059ecb7c8c1632d5d7ff17bb38dddbd
3
+ size 10537088
meme_search.db ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbe21217d65e17c68c806d6baed8ac9da166fd65d34e332f93033f357b8d65ad
3
+ size 6897664
requirements.txt ADDED
@@ -0,0 +1,169 @@
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ antlr4-python3-runtime==4.9.3
4
+ anyio==4.6.2.post1
5
+ argon2-cffi==23.1.0
6
+ argon2-cffi-bindings==21.2.0
7
+ arrow==1.3.0
8
+ asttokens==2.4.1
9
+ async-lru==2.0.4
10
+ attrs==24.2.0
11
+ babel==2.16.0
12
+ beautifulsoup4==4.12.3
13
+ bleach==6.1.0
14
+ certifi==2024.8.30
15
+ cffi==1.17.1
16
+ charset-normalizer==3.4.0
17
+ click==8.1.7
18
+ comm==0.2.2
19
+ debugpy==1.8.7
20
+ decorator==5.1.1
21
+ defusedxml==0.7.1
22
+ eval_type_backport==0.2.0
23
+ executing==2.1.0
24
+ fastapi==0.115.2
25
+ fastjsonschema==2.20.0
26
+ ffmpy==0.4.0
27
+ filelock==3.16.1
28
+ fire==0.7.0
29
+ fqdn==1.5.1
30
+ fsspec==2024.9.0
31
+ gradio==5.1.0
32
+ gradio_client==1.4.0
33
+ greenlet==3.1.1
34
+ h11==0.14.0
35
+ httpcore==1.0.6
36
+ httpx==0.27.2
37
+ huggingface-hub==0.25.2
38
+ idna==3.10
39
+ ipykernel==6.29.5
40
+ ipython==8.28.0
41
+ ipywidgets==8.1.5
42
+ isoduration==20.11.0
43
+ jedi==0.19.1
44
+ Jinja2==3.1.4
45
+ joblib==1.4.2
46
+ json5==0.9.25
47
+ jsonpath-python==1.0.6
48
+ jsonpointer==3.0.0
49
+ jsonschema==4.23.0
50
+ jsonschema-specifications==2024.10.1
51
+ jupyter==1.1.1
52
+ jupyter-console==6.6.3
53
+ jupyter-events==0.10.0
54
+ jupyter-lsp==2.2.5
55
+ jupyter_client==8.6.3
56
+ jupyter_core==5.7.2
57
+ jupyter_server==2.14.2
58
+ jupyter_server_terminals==0.5.3
59
+ jupyterlab==4.2.5
60
+ jupyterlab_pygments==0.3.0
61
+ jupyterlab_server==2.27.3
62
+ jupyterlab_widgets==3.0.13
63
+ loguru==0.7.2
64
+ markdown-it-py==3.0.0
65
+ MarkupSafe==2.1.5
66
+ matplotlib-inline==0.1.7
67
+ mdurl==0.1.2
68
+ mistralai==1.1.0
69
+ mistune==3.0.2
70
+ mpmath==1.3.0
71
+ mypy-extensions==1.0.0
72
+ nbclient==0.10.0
73
+ nbconvert==7.16.4
74
+ nbformat==5.10.4
75
+ nest-asyncio==1.6.0
76
+ networkx==3.4.1
77
+ nltk==3.9.1
78
+ notebook==7.2.2
79
+ notebook_shim==0.2.4
80
+ numpy==2.1.2
81
+ nvidia-cublas-cu12==12.4.5.8
82
+ nvidia-cuda-cupti-cu12==12.4.127
83
+ nvidia-cuda-nvrtc-cu12==12.4.127
84
+ nvidia-cuda-runtime-cu12==12.4.127
85
+ nvidia-cudnn-cu12==9.1.0.70
86
+ nvidia-cufft-cu12==11.2.1.3
87
+ nvidia-curand-cu12==10.3.5.147
88
+ nvidia-cusolver-cu12==11.6.1.9
89
+ nvidia-cusparse-cu12==12.3.1.170
90
+ nvidia-nccl-cu12==2.21.5
91
+ nvidia-nvjitlink-cu12==12.4.127
92
+ nvidia-nvtx-cu12==12.4.127
93
+ omegaconf==2.3.0
94
+ orjson==3.10.7
95
+ overrides==7.7.0
96
+ packaging==24.1
97
+ pandas==2.2.3
98
+ pandocfilters==1.5.1
99
+ parso==0.8.4
100
+ pexpect==4.9.0
101
+ pillow==10.4.0
102
+ platformdirs==4.3.6
103
+ prometheus_client==0.21.0
104
+ prompt_toolkit==3.0.48
105
+ psutil==6.1.0
106
+ ptyprocess==0.7.0
107
+ pure_eval==0.2.3
108
+ pycparser==2.22
109
+ pydantic==2.9.2
110
+ pydantic_core==2.23.4
111
+ pydub==0.25.1
112
+ Pygments==2.18.0
113
+ pymystem3==0.2.0
114
+ python-dateutil==2.8.2
115
+ python-json-logger==2.0.7
116
+ python-multipart==0.0.12
117
+ pytz==2024.2
118
+ PyYAML==6.0.2
119
+ pyzmq==26.2.0
120
+ rank-bm25==0.2.2
121
+ referencing==0.35.1
122
+ regex==2024.9.11
123
+ requests==2.32.3
124
+ rfc3339-validator==0.1.4
125
+ rfc3986-validator==0.1.1
126
+ rich==13.9.2
127
+ rpds-py==0.20.0
128
+ ruff==0.7.0
129
+ safetensors==0.4.5
130
+ scikit-learn==1.5.2
131
+ scipy==1.14.1
132
+ semantic-version==2.10.0
133
+ Send2Trash==1.8.3
134
+ sentence-transformers==3.2.0
135
+ shellingham==1.5.4
136
+ six==1.16.0
137
+ sniffio==1.3.1
138
+ soupsieve==2.6
139
+ SQLAlchemy==2.0.36
140
+ stack-data==0.6.3
141
+ starlette==0.40.0
142
+ sympy==1.13.1
143
+ termcolor==2.5.0
144
+ terminado==0.18.1
145
+ threadpoolctl==3.5.0
146
+ tinycss2==1.3.0
147
+ tokenizers==0.20.1
148
+ tomlkit==0.12.0
149
+ torch==2.5.0
150
+ tornado==6.4.1
151
+ tqdm==4.66.5
152
+ traitlets==5.14.3
153
+ transformers==4.45.2
154
+ triton==3.1.0
155
+ typer==0.12.5
156
+ types-python-dateutil==2.9.0.20241003
157
+ typing-inspect==0.9.0
158
+ typing_extensions==4.12.2
159
+ tzdata==2024.2
160
+ uri-template==1.3.0
161
+ urllib3==2.2.3
162
+ uvicorn==0.32.0
163
+ vk-api==11.9.9
164
+ wcwidth==0.2.13
165
+ webcolors==24.8.0
166
+ webencodings==0.5.1
167
+ websocket-client==1.8.0
168
+ websockets==12.0
169
+ widgetsnbextension==4.0.13
scripts/build_bm25_index.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Dict, Any, List
5
+
6
+ from loguru import logger
7
+ from omegaconf import OmegaConf
8
+ from sqlalchemy import create_engine
9
+ from sqlalchemy.orm import sessionmaker
10
+
11
+
12
+ def get_meme_corpus(db, crud) -> List[str]:
13
+ """
14
+ Retrieve all meme texts from the database.
15
+
16
+ Args:
17
+ db: Database session.
18
+ crud: CRUD operations module.
19
+
20
+ Returns:
21
+ List[str]: List of meme texts.
22
+ """
23
+ memes = crud.get_all_memes(db)
24
+ corpus = [meme.text for meme in memes]
25
+ logger.info(f"Retrieved {len(corpus)} memes from the database")
26
+ return corpus
27
+
28
+
29
+ def build_bm25_index(corpus: List[str],
30
+ config: Dict[str,
31
+ Any],
32
+ mystem_tokenizer,
33
+ BM25Indexer):
34
+ """
35
+ Build and save the BM25 index.
36
+
37
+ Args:
38
+ corpus (List[str]): List of meme texts.
39
+ config (Dict[str, Any]): Configuration dictionary.
40
+ mystem_tokenizer: MystemTokenizer instance.
41
+ BM25Indexer: BM25Indexer class.
42
+ """
43
+ indexer = BM25Indexer(corpus, mystem_tokenizer.tokenize)
44
+ bm25_index_folder = config['index_folders']['bm25']
45
+ os.makedirs(bm25_index_folder, exist_ok=True)
46
+ indexer.create_index(bm25_index_folder)
47
+ logger.info(f"BM25S index created and saved in {bm25_index_folder}")
48
+
49
+
50
+ def main():
51
+ from src.db import crud
52
+ from src.preprocessing.mystem_tokenizer import MystemTokenizer
53
+ from src.indexing.bm25_indexer import BM25Indexer
54
+
55
+ logger.add("logs/build_bm25s_index.log", rotation="10 MB")
56
+
57
+ # Load configuration
58
+ config = OmegaConf.load('config.yaml')
59
+ config = OmegaConf.to_container(config)
60
+
61
+ engine = create_engine(config['database']['url'])
62
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
63
+ db = SessionLocal()
64
+
65
+ try:
66
+ corpus = get_meme_corpus(db, crud)
67
+ mystem_tokenizer = MystemTokenizer()
68
+ build_bm25_index(corpus, config, mystem_tokenizer, BM25Indexer)
69
+ finally:
70
+ db.close()
71
+
72
+ logger.info("BM25S index building completed")
73
+
74
+
75
+ if __name__ == "__main__":
76
+ # Set up project root path
77
+ project_root = Path(__file__).resolve().parents[1]
78
+ sys.path.insert(0, str(project_root))
79
+ main()
scripts/build_semantic_index.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Dict, Any, List
5
+
6
+ from loguru import logger
7
+ from omegaconf import OmegaConf
8
+ from sqlalchemy import create_engine
9
+ from sqlalchemy.orm import sessionmaker
10
+
11
+
12
+ def get_meme_corpus(db, crud) -> List[str]:
13
+ """
14
+ Retrieve all meme texts from the database.
15
+
16
+ Args:
17
+ db: Database session.
18
+ crud: CRUD operations module.
19
+
20
+ Returns:
21
+ List[str]: List of meme texts.
22
+ """
23
+ memes = crud.get_all_memes(db)
24
+ corpus = [meme.text for meme in memes]
25
+ logger.info(f"Retrieved {len(corpus)} memes from the database")
26
+ return corpus
27
+
28
+
29
+ def build_semantic_index(
30
+ corpus: List[str], config: Dict[str, Any], SemanticIndexer):
31
+ """
32
+ Build and save the semantic index.
33
+
34
+ Args:
35
+ corpus (List[str]): List of meme texts.
36
+ config (Dict[str, Any]): Configuration dictionary.
37
+ SemanticIndexer: SemanticIndexer class.
38
+ """
39
+ model = config['semantic_search']['model']
40
+ prefix = config['semantic_search']['document_prefix']
41
+ indexer = SemanticIndexer(corpus, model=model, prefix=prefix)
42
+
43
+ semantic_index_folder = config['index_folders']['semantic']
44
+ os.makedirs(semantic_index_folder, exist_ok=True)
45
+ indexer.create_index(semantic_index_folder)
46
+ logger.info(f"Semantic index created and saved in {semantic_index_folder}")
47
+
48
+
49
+ def main():
50
+ from src.db import crud
51
+ from src.indexing.semantic_indexer import SemanticIndexer
52
+
53
+ logger.add("logs/build_semantic_index.log", rotation="10 MB")
54
+
55
+ # Load configuration
56
+ config = OmegaConf.load('config.yaml')
57
+ config = OmegaConf.to_container(config)
58
+
59
+ engine = create_engine(config['database']['url'])
60
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
61
+ db = SessionLocal()
62
+
63
+ try:
64
+ corpus = get_meme_corpus(db, crud)
65
+ build_semantic_index(corpus, config, SemanticIndexer)
66
+ finally:
67
+ db.close()
68
+
69
+ logger.info("Semantic index building completed")
70
+
71
+
72
+ if __name__ == "__main__":
73
+ # Set up project root path
74
+ project_root = Path(__file__).resolve().parents[1]
75
+ sys.path.insert(0, str(project_root))
76
+ main()
scripts/data_collector.py ADDED
@@ -0,0 +1,77 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Dict, Any
6
+
7
+ from loguru import logger
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ def process_public(parser, public_id: str, config: Dict[str, Any]) -> None:
12
+ """
13
+ Process a single public page, updating or creating its JSON file.
14
+
15
+ Args:
16
+ parser: VK meme parser instance.
17
+ public_id (str): ID or short name of the public page.
18
+ config (Dict[str, Any]): Configuration dictionary.
19
+ """
20
+ raw_data_path = config['data_folders']['raw_data']
21
+ json_file_path = os.path.join(raw_data_path, f"{public_id}.json")
22
+
23
+ logger.info(f"Processing public: {public_id}")
24
+
25
+ memes_data = parser.get_memes(public_id)
26
+
27
+ if os.path.exists(json_file_path):
28
+ # Update existing JSON file
29
+ with open(json_file_path, 'r', encoding='utf-8') as file:
30
+ existing_data = json.load(file)
31
+
32
+ existing_posts = {post['id']: post for post in existing_data['posts']}
33
+ new_posts = [post for post in memes_data['posts']
34
+ if post['id'] not in existing_posts]
35
+
36
+ # Add new posts to the beginning of the list
37
+ existing_data['posts'] = new_posts + existing_data['posts']
38
+
39
+ with open(json_file_path, 'w', encoding='utf-8') as file:
40
+ json.dump(existing_data, file, ensure_ascii=False, indent=2)
41
+
42
+ logger.info(f"Updated {len(new_posts)} new posts for {public_id}")
43
+
44
+ else:
45
+ # Create new JSON file
46
+ with open(json_file_path, 'w', encoding='utf-8') as file:
47
+ json.dump(memes_data, file, ensure_ascii=False, indent=2)
48
+
49
+ logger.info(
50
+ f"Created new JSON file for {public_id} with {len(memes_data['posts'])} posts")
51
+
52
+
53
+ def main():
54
+ from src.parsing.vk_meme_parser import VKMemeParser
55
+
56
+ logger.add("logs/data_collector.log", rotation="10 MB")
57
+
58
+ # Load configuration
59
+ config = OmegaConf.load('config.yaml')
60
+ config = OmegaConf.to_container(config)
61
+
62
+ parser = VKMemeParser(config['vk_parser']['api_token'])
63
+
64
+ for folder in config['data_folders'].values():
65
+ os.makedirs(folder, exist_ok=True)
66
+
67
+ for public_id in config['vk_parser']['meme_pages']:
68
+ process_public(parser, public_id, config)
69
+
70
+ logger.info("Data collection process completed")
71
+
72
+
73
+ if __name__ == "__main__":
74
+ # Set up project root path
75
+ project_root = Path(__file__).resolve().parents[1]
76
+ sys.path.insert(0, str(project_root))
77
+ main()
scripts/make_db.py ADDED
@@ -0,0 +1,100 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Dict, Any, List
6
+
7
+ from loguru import logger
8
+ from omegaconf import OmegaConf
9
+ from sqlalchemy import create_engine
10
+ from sqlalchemy.orm import sessionmaker
11
+
12
+
13
+ def process_json_files(
14
+ raw_data_path: str) -> tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
15
+ """
16
+ Process all JSON files in the raw data folder.
17
+
18
+ Args:
19
+ raw_data_path (str): Path to the folder containing JSON files.
20
+
21
+ Returns:
22
+ tuple: Lists of public and meme data to be added to the database.
23
+ """
24
+ publics_to_add: List[Dict[str, str]] = []
25
+ memes_to_add: List[Dict[str, Any]] = []
26
+
27
+ for filename in os.listdir(raw_data_path):
28
+ if filename.endswith('.json'):
29
+ public_vk = filename[:-5] # Remove .json extension
30
+ file_path = os.path.join(raw_data_path, filename)
31
+
32
+ with open(file_path, 'r', encoding='utf-8') as file:
33
+ data = json.load(file)
34
+
35
+ publics_to_add.append({
36
+ "public_vk": public_vk,
37
+ "public_name": data['name']
38
+ })
39
+
40
+ for post in data['posts']:
41
+ memes_to_add.append({
42
+ "public_vk": public_vk,
43
+ "text": post['text'],
44
+ "image_url": post['image_url']
45
+ })
46
+
47
+ logger.info(
48
+ f"Processed file: {filename}, found {len(data['posts'])} memes")
49
+
50
+ return publics_to_add, memes_to_add
51
+
52
+
53
+ def main():
54
+ from src.db.models import Base
55
+ from src.db import crud
56
+
57
+ logger.add("logs/make_db.log", rotation="10 MB")
58
+
59
+ # Load configuration
60
+ config = OmegaConf.load('config.yaml')
61
+ config = OmegaConf.to_container(config)
62
+
63
+ engine = create_engine(config['database']['url'])
64
+
65
+ # Drop all existing tables and create new ones
66
+ Base.metadata.drop_all(bind=engine)
67
+ Base.metadata.create_all(bind=engine)
68
+
69
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
70
+ db = SessionLocal()
71
+
72
+ raw_data_path = config['data_folders']['raw_data']
73
+
74
+ publics_to_add, memes_to_add = process_json_files(raw_data_path)
75
+
76
+ # Add all publics to the database
77
+ added_publics = crud.add_publics(db, publics_to_add)
78
+
79
+ # Create a mapping of public_vk to public_id
80
+ public_vk_to_id = {public.public_vk: public.id for public in added_publics}
81
+
82
+ # Update memes with correct public_id
83
+ for meme in memes_to_add:
84
+ meme['public_id'] = public_vk_to_id[meme.pop('public_vk')]
85
+
86
+ # Add all memes to the database
87
+ crud.add_memes(db, memes_to_add)
88
+
89
+ logger.info(
90
+ f"Added {len(added_publics)} publics and {len(memes_to_add)} memes to the database")
91
+
92
+ db.close()
93
+ logger.info("Database population completed")
94
+
95
+
96
+ if __name__ == "__main__":
97
+ # Set up project root path
98
+ project_root = Path(__file__).resolve().parents[1]
99
+ sys.path.insert(0, str(project_root))
100
+ main()
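
Taken together, the scripts form a small pipeline: data_collector.py fetches posts from VK into data/raw, make_db.py rebuilds meme_search.db from those JSON files, and the two index builders read the database. A sketch of running them in order from the repo root (it assumes a valid VK token in config.yaml; the collector step can be skipped since data/raw/textmeme.json is already committed):

    import runpy

    # Each script guards its work behind `if __name__ == "__main__"`, so run_name="__main__"
    # executes it the same way `python scripts/<name>.py` would.
    for script in ("scripts/data_collector.py", "scripts/make_db.py",
                   "scripts/build_bm25_index.py", "scripts/build_semantic_index.py"):
        runpy.run_path(script, run_name="__main__")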
src/db/crud.py ADDED
@@ -0,0 +1,210 @@
1
+ from typing import List, Dict, Any
2
+
3
+ from sqlalchemy.orm import Session
4
+
5
+ from . import models
6
+
7
+
8
+ def add_publics(
9
+ db: Session, publics: List[Dict[str, str]]) -> List[models.Public]:
10
+ """
11
+ Add multiple public pages to the database.
12
+
13
+ Args:
14
+ db (Session): The database session.
15
+ publics (List[Dict[str, str]]): List of public page data.
16
+
17
+ Returns:
18
+ List[models.Public]: List of added public page objects.
19
+ """
20
+ db_publics = [models.Public(**public) for public in publics]
21
+ db.add_all(db_publics)
22
+ db.commit()
23
+ for public in db_publics:
24
+ db.refresh(public)
25
+ return db_publics
26
+
27
+
28
+ def add_memes(db: Session, memes: List[Dict[str, Any]]) -> List[models.Meme]:
29
+ """
30
+ Add multiple memes to the database.
31
+
32
+ Args:
33
+ db (Session): The database session.
34
+ memes (List[Dict[str, Any]]): List of meme data.
35
+
36
+ Returns:
37
+ List[models.Meme]: List of added meme objects.
38
+ """
39
+ db_memes = [models.Meme(**meme) for meme in memes]
40
+ db.add_all(db_memes)
41
+ db.commit()
42
+ for meme in db_memes:
43
+ db.refresh(meme)
44
+ return db_memes
45
+
46
+
47
+ def get_memes_by_publics(db: Session,
48
+ public_ids: List[int]) -> List[models.Meme]:
49
+ """
50
+ Retrieve memes associated with specific public pages.
51
+
52
+ Args:
53
+ db (Session): The database session.
54
+ public_ids (List[int]): List of public page IDs.
55
+
56
+ Returns:
57
+ List[models.Meme]: List of meme objects.
58
+ """
59
+ return db.query(
60
+ models.Meme).filter(
61
+ models.Meme.public_id.in_(public_ids)).all()
62
+
63
+
64
+ def get_all_memes(db: Session) -> List[models.Meme]:
65
+ """
66
+ Retrieve all memes from the database.
67
+
68
+ Args:
69
+ db (Session): The database session.
70
+
71
+ Returns:
72
+ List[models.Meme]: List of all meme objects.
73
+ """
74
+ return db.query(models.Meme).all()
75
+
76
+
77
+ def get_all_publics(db: Session) -> List[models.Public]:
78
+ """
79
+ Retrieve all public pages from the database.
80
+
81
+ Args:
82
+ db (Session): The database session.
83
+
84
+ Returns:
85
+ List[models.Public]: List of all public page objects.
86
+ """
87
+ return db.query(models.Public).all()
88
+
89
+
90
+ def get_memes_by_ids(db: Session, meme_ids: List[int]) -> List[models.Meme]:
91
+ """
92
+ Retrieve memes by their IDs.
93
+
94
+ Args:
95
+ db (Session): The database session.
96
+ meme_ids (List[int]): List of meme IDs.
97
+
98
+ Returns:
99
+ List[models.Meme]: List of meme objects.
100
+ """
101
+ return db.query(models.Meme).filter(models.Meme.id.in_(meme_ids)).all()
102
+
103
+
104
+ def get_publics_by_ids(db: Session,
105
+ public_ids: List[int]) -> List[models.Public]:
106
+ """
107
+ Retrieve public pages by their IDs.
108
+
109
+ Args:
110
+ db (Session): The database session.
111
+ public_ids (List[int]): List of public page IDs.
112
+
113
+ Returns:
114
+ List[models.Public]: List of public page objects.
115
+ """
116
+ return db.query(
117
+ models.Public).filter(
118
+ models.Public.id.in_(public_ids)).all()
119
+
120
+
121
+ def delete_memes(db: Session, meme_ids: List[int]) -> int:
122
+ """
123
+ Delete memes by their IDs.
124
+
125
+ Args:
126
+ db (Session): The database session.
127
+ meme_ids (List[int]): List of meme IDs to delete.
128
+
129
+ Returns:
130
+ int: Number of deleted memes.
131
+ """
132
+ deleted_count = db.query(
133
+ models.Meme).filter(
134
+ models.Meme.id.in_(meme_ids)).delete(
135
+ synchronize_session='fetch')
136
+ db.commit()
137
+ return deleted_count
138
+
139
+
140
+ def delete_publics(db: Session, public_ids: List[int]) -> int:
141
+ """
142
+ Delete public pages and their associated memes.
143
+
144
+ Args:
145
+ db (Session): The database session.
146
+ public_ids (List[int]): List of public page IDs to delete.
147
+
148
+ Returns:
149
+ int: Number of deleted public pages.
150
+ """
151
+ # First, delete associated memes
152
+ db.query(
153
+ models.Meme).filter(
154
+ models.Meme.public_id.in_(public_ids)).delete(
155
+ synchronize_session='fetch')
156
+
157
+ # Then delete the publics
158
+ deleted_count = db.query(
159
+ models.Public).filter(
160
+ models.Public.id.in_(public_ids)).delete(
161
+ synchronize_session='fetch')
162
+ db.commit()
163
+ return deleted_count
164
+
165
+
166
+ def get_memes_with_public_info(
167
+ db: Session, meme_ids: List[int]) -> List[tuple[models.Meme, models.Public]]:
168
+ """
169
+ Retrieve memes with their associated public page information.
170
+
171
+ Args:
172
+ db (Session): The database session.
173
+ meme_ids (List[int]): List of meme IDs.
174
+
175
+ Returns:
176
+ List[tuple[models.Meme, models.Public]]: List of tuples containing meme and public page objects.
177
+ """
178
+ return db.query(models.Meme, models.Public).\
179
+ join(models.Public, models.Meme.public_id == models.Public.id).\
180
+ filter(models.Meme.id.in_(meme_ids)).all()
181
+
182
+
183
+ def update_memes(db: Session, meme_updates: List[Dict[str, Any]]) -> None:
184
+ """
185
+ Update multiple memes in the database.
186
+
187
+ Args:
188
+ db (Session): The database session.
189
+ meme_updates (List[Dict[str, Any]]): List of meme update data.
190
+ """
191
+ for update in meme_updates:
192
+ meme_id = update.pop('id')
193
+ db.query(models.Meme).filter(models.Meme.id == meme_id).update({getattr(
194
+ models.Meme, k): v for k, v in update.items()}, synchronize_session='fetch')
195
+ db.commit()
196
+
197
+
198
+ def update_publics(db: Session, public_updates: List[Dict[str, Any]]) -> None:
199
+ """
200
+ Update multiple public pages in the database.
201
+
202
+ Args:
203
+ db (Session): The database session.
204
+ public_updates (List[Dict[str, Any]]): List of public page update data.
205
+ """
206
+ for update in public_updates:
207
+ public_id = update.pop('id')
208
+ db.query(models.Public).filter(models.Public.id == public_id).update({getattr(
209
+ models.Public, k): v for k, v in update.items()}, synchronize_session='fetch')
210
+ db.commit()
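
A minimal usage sketch of these CRUD helpers against a throwaway in-memory SQLite database, run from the repo root so that `src` is importable (the names and values below are made up for illustration):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from src.db import crud
    from src.db.models import Base

    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=engine)
    db = sessionmaker(bind=engine)()

    # Insert one public and one meme, then read everything back.
    publics = crud.add_publics(db, [{"public_vk": "textmeme", "public_name": "Example public"}])
    crud.add_memes(db, [{"public_id": publics[0].id,
                         "text": "example meme text",
                         "image_url": "https://example.com/1.jpg"}])
    print(len(crud.get_all_memes(db)))  # 1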
src/db/models.py ADDED
@@ -0,0 +1,28 @@
1
+ from sqlalchemy import Column, Integer, String, ForeignKey
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+ from sqlalchemy.orm import relationship
4
+
5
+
6
+ Base = declarative_base()
7
+
8
+
9
+ class Public(Base):
10
+ __tablename__ = "publics"
11
+
12
+ id = Column(Integer, primary_key=True, index=True, autoincrement=True)
13
+ public_vk = Column(String, unique=True, index=True)
14
+ public_name = Column(String)
15
+
16
+ memes = relationship("Meme", back_populates="public")
17
+
18
+
19
+ class Meme(Base):
20
+ __tablename__ = "memes"
21
+
22
+ id = Column(Integer, primary_key=True, index=True, autoincrement=True)
23
+ public_id = Column(Integer, ForeignKey("publics.id"))
24
+ text = Column(String)
25
+ image_url = Column(String)
26
+ local_image_path = Column(String)
27
+
28
+ public = relationship("Public", back_populates="memes")
src/indexing/bm25_indexer.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import pickle
3
+ from typing import List, Callable
4
+
5
+ from rank_bm25 import BM25Okapi
6
+
7
+
8
+ class BM25Indexer:
9
+ def __init__(self, corpus: List[str],
10
+ tokenizer: Callable[[str], List[str]]):
11
+ """
12
+ Initialize the BM25Indexer.
13
+
14
+ Args:
15
+ corpus (List[str]): The corpus to be indexed.
16
+ tokenizer (Callable[[str], List[str]]): A function to tokenize the text.
17
+ """
18
+ self.corpus = corpus
19
+ self.tokenizer = tokenizer
20
+ self.bm25 = None
21
+
22
+ def create_index(self, save_dir: str) -> None:
23
+ """
24
+ Create and save the BM25 index.
25
+
26
+ Args:
27
+ save_dir (str): Directory to save the index.
28
+ """
29
+ # Ensure the save directory exists
30
+ os.makedirs(save_dir, exist_ok=True)
31
+
32
+ # Tokenize the corpus
33
+ tokenized_corpus = [self.tokenizer(doc) for doc in self.corpus]
34
+
35
+ # Create the BM25 model
36
+ self.bm25 = BM25Okapi(tokenized_corpus)
37
+
38
+ # Save the BM25 index
39
+ with open(os.path.join(save_dir, 'bm25_index.pkl'), 'wb') as f:
40
+ pickle.dump(self.bm25, f)
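
A quick sketch of the indexer on a toy corpus, using a plain whitespace tokenizer in place of MystemTokenizer (the output folder name is arbitrary; run from the repo root):

    from src.indexing.bm25_indexer import BM25Indexer

    corpus = ["кот сидит на окне", "собака бежит по парку"]
    indexer = BM25Indexer(corpus, tokenizer=str.split)
    indexer.create_index("indexes/bm25_demo")  # writes indexes/bm25_demo/bm25_index.pkl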
src/indexing/semantic_indexer.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
+ class SemanticIndexer:
9
+ def __init__(
10
+ self,
11
+ corpus: List[str],
12
+ model: str,
13
+ prefix: Optional[str] = None
14
+ ):
15
+ """
16
+ Initialize the SemanticIndexer.
17
+
18
+ Args:
19
+ corpus (List[str]): The corpus to be indexed.
20
+ model (str): The name or path of the SentenceTransformer model to use.
21
+ prefix (Optional[str], optional): A prefix to add to each text in the corpus. Defaults to None.
22
+ """
23
+ self.corpus = corpus
24
+ self.model = SentenceTransformer(model)
25
+ self.prefix = prefix
26
+ self.embeddings = None
27
+
28
+ def create_index(self, save_dir: str) -> None:
29
+ """
30
+ Create and save the semantic index.
31
+
32
+ Args:
33
+ save_dir (str): Directory to save the embeddings.
34
+ """
35
+ # Ensure the save directory exists
36
+ os.makedirs(save_dir, exist_ok=True)
37
+
38
+ # Prepare texts with prefix if provided
39
+ texts = [
40
+ f"{self.prefix}{text}" if self.prefix else text for text in self.corpus]
41
+
42
+ # Create embeddings
43
+ self.embeddings = self.model.encode(
44
+ texts,
45
+ show_progress_bar=True,
46
+ convert_to_numpy=True,
47
+ normalize_embeddings=True
48
+ )
49
+
50
+ # Save embeddings
51
+ embeddings_file = os.path.join(save_dir, "embeddings.npy")
52
+ np.save(embeddings_file, self.embeddings)
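
And the semantic counterpart; the model name and document prefix below mirror config.yaml, and the model is downloaded from the Hugging Face Hub on first use:

    from src.indexing.semantic_indexer import SemanticIndexer

    corpus = ["кот сидит на окне", "собака бежит по парку"]
    indexer = SemanticIndexer(corpus,
                              model="intfloat/multilingual-e5-small",
                              prefix="passage: ")
    indexer.create_index("indexes/semantic_demo")  # writes embeddings.npy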
src/interface.py ADDED
@@ -0,0 +1,124 @@
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ from omegaconf import OmegaConf
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy.orm import sessionmaker
8
+
9
+
10
+ def initialize_search_engines(db, config):
11
+ """
12
+ Initialize both BM25 and Semantic search engines.
13
+
14
+ Args:
15
+ db: Database session.
16
+ config: Configuration dictionary.
17
+
18
+ Returns:
19
+ tuple: Initialized BM25Search and SemanticSearch engines.
20
+ """
21
+ from search.bm25_search import BM25Search
22
+ from search.semantic_search import SemanticSearch
23
+ from preprocessing.mystem_tokenizer import MystemTokenizer
24
+
25
+ custom_tokenizer = MystemTokenizer()
26
+ bm25_search = BM25Search(
27
+ db,
28
+ config['index_folders']['bm25'],
29
+ custom_tokenizer.tokenize
30
+ )
31
+ semantic_search = SemanticSearch(
32
+ db,
33
+ model=config['semantic_search']['model'],
34
+ embeddings_file=f"{config['index_folders']['semantic']}/embeddings.npy",
35
+ prefix=config['semantic_search']['query_prefix'])
36
+
37
+ return bm25_search, semantic_search
38
+
39
+
40
+ def search_memes(query: str, search_type: str, num_results: int):
41
+ """
42
+ Search for memes using the specified search method.
43
+
44
+ Args:
45
+ query (str): The search query.
46
+ search_type (str): The type of search to perform. Either 'BM25' or 'Семантический'.
47
+ num_results (int): The number of results to return.
48
+
49
+ Returns:
50
+ tuple: A tuple containing the search results and search time.
51
+ """
52
+ if search_type == "BM25":
53
+ results = bm25_search.search(query, num_results)
54
+ else:
55
+ results = semantic_search.search(query, num_results)
56
+
57
+ output = []
58
+ for result in results['results']:
59
+ output.append((result['image_url'], result['text']))
60
+
61
+ return output, f"Время поиска: {results['search_time']:.4f} секунд"
62
+
63
+
64
+ def main():
65
+ global bm25_search, semantic_search
66
+
67
+ # Load configuration
68
+ config = OmegaConf.load('config.yaml')
69
+ config = OmegaConf.to_container(config)
70
+
71
+ # Initialize database session
72
+ engine = create_engine(config['database']['url'])
73
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
74
+ db = SessionLocal()
75
+
76
+ # Initialize search engines
77
+ bm25_search, semantic_search = initialize_search_engines(db, config)
78
+
79
+ # Gradio interface
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# Поиск мемов")
82
+ gr.Markdown(
83
+ "Добро пожаловать в приложение для поиска мемов! Введите запрос, выберите тип поиска и количество результатов."
84
+ )
85
+
86
+ with gr.Row():
87
+ query = gr.Textbox(label="Запрос")
88
+ search_type = gr.Radio(
89
+ ["BM25", "Семантический"],
90
+ label="Тип поиска",
91
+ value="BM25"
92
+ )
93
+ num_results = gr.Slider(
94
+ minimum=1,
95
+ maximum=10,
96
+ step=1,
97
+ value=1,
98
+ label="Количество результатов"
99
+ )
100
+
101
+ search_button = gr.Button("Найти")
102
+
103
+ output_gallery = gr.Gallery(
104
+ label="Результаты",
105
+ show_label=False,
106
+ columns=3,
107
+ height=400
108
+ )
109
+ output_time = gr.Markdown()
110
+
111
+ search_button.click(
112
+ fn=search_memes,
113
+ inputs=[query, search_type, num_results],
114
+ outputs=[output_gallery, output_time]
115
+ )
116
+
117
+ demo.launch()
118
+
119
+
120
+ if __name__ == "__main__":
121
+ # Set up project root path
122
+ project_root = Path(__file__).resolve().parents[1]
123
+ sys.path.insert(0, str(project_root))
124
+ main()
src/main.py ADDED
@@ -0,0 +1,104 @@
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import fire
5
+ from omegaconf import OmegaConf
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy.orm import sessionmaker
8
+
9
+
10
+ def initialize_bm25_search(db, config):
11
+ """
12
+ Initialize BM25 search engine.
13
+
14
+ Args:
15
+ db: Database session.
16
+ config: Configuration dictionary.
17
+
18
+ Returns:
19
+ BM25Search: Initialized BM25 search engine.
20
+ """
21
+ from search.bm25_search import BM25Search
22
+ from preprocessing.mystem_tokenizer import MystemTokenizer
23
+
24
+ custom_tokenizer = MystemTokenizer()
25
+ return BM25Search(
26
+ db,
27
+ config['index_folders']['bm25'],
28
+ custom_tokenizer.tokenize
29
+ )
30
+
31
+
32
+ def initialize_semantic_search(db, config):
33
+ """
34
+ Initialize semantic search engine.
35
+
36
+ Args:
37
+ db: Database session.
38
+ config: Configuration dictionary.
39
+
40
+ Returns:
41
+ SemanticSearch: Initialized semantic search engine.
42
+ """
43
+ from search.semantic_search import SemanticSearch
44
+ return SemanticSearch(
45
+ db,
46
+ model=config['semantic_search']['model'],
47
+ embeddings_file=f"{config['index_folders']['semantic']}/embeddings.npy",
48
+ prefix=config['semantic_search']['query_prefix'])
49
+
50
+
51
+ def search_memes(query: str, search_type: str = 'bm25', num: int = 1):
52
+ """
53
+ Search for memes using the specified search method.
54
+
55
+ Args:
56
+ query (str): The search query.
57
+ search_type (str): The type of search to perform. Either 'bm25' or 'semantic'. Defaults to 'bm25'.
58
+ num (int): The number of results to return. Defaults to 1.
59
+
60
+ Returns:
61
+ None: Prints the results to the console.
62
+ """
63
+ if not query:
64
+ print("Error: Query is required.")
65
+ return
66
+ if search_type not in ['bm25', 'semantic']:
67
+ print("Error: Invalid search type. Use 'bm25' or 'semantic'.")
68
+ return
69
+ if num < 1:
70
+ print("Error: Number of results must be at least 1.")
71
+ return
72
+
73
+ # Load configuration
74
+ config = OmegaConf.load('config.yaml')
75
+ config = OmegaConf.to_container(config)
76
+
77
+ # Initialize database session
78
+ engine = create_engine(config['database']['url'])
79
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
80
+ db = SessionLocal()
81
+
82
+ try:
83
+ # Initialize search engine
84
+ if search_type == 'bm25':
85
+ search_engine = initialize_bm25_search(db, config)
86
+ elif search_type == 'semantic':
87
+ search_engine = initialize_semantic_search(db, config)
88
+
89
+ # Perform search
90
+ results = search_engine.search(query, num)
91
+
92
+ # Print results
93
+ for result in results['results']:
94
+ print(result['text'])
95
+ print(f"\nSearch time: {results['search_time']:.4f} seconds")
96
+ finally:
97
+ db.close()
98
+
99
+
100
+ if __name__ == "__main__":
101
+ # Set up project root path
102
+ project_root = Path(__file__).resolve().parents[1]
103
+ sys.path.insert(0, str(project_root))
104
+ fire.Fire(search_memes)
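
search_memes is exposed through python-fire, so it is normally invoked from the repo root as a CLI. A sketch of calling it in-process instead; the sys.path entries mirror what running `python src/main.py` sets up, and the query text is just an example:

    import sys
    sys.path.extend([".", "src"])  # repo root for `src.*` imports, src/ for `search.*` imports

    from main import search_memes
    search_memes("мем про кота", search_type="semantic", num=3)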
src/parsing/vk_meme_parser.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ from typing import Optional, Dict, Any
3
+ from urllib.parse import urlparse
4
+
5
+ import requests
6
+ import vk_api
7
+
8
+
9
+ class VKMemeParser:
10
+ def __init__(self, token: str):
11
+ """
12
+ Initialize the VK Meme Parser.
13
+
14
+ Args:
15
+ token (str): VK API access token.
16
+ """
17
+ self.vk_session = vk_api.VkApi(token=token)
18
+ self.vk = self.vk_session.get_api()
19
+
20
+ def _process_post(self, post: Dict[str, Any]) -> Optional[Dict[str, Any]]:
21
+ """
22
+ Process a single post and extract relevant information.
23
+
24
+ Args:
25
+ post (Dict[str, Any]): A dictionary containing post data.
26
+
27
+ Returns:
28
+ Optional[Dict[str, Any]]: A dictionary with post ID, text, and image URL if valid,
29
+ None otherwise.
30
+ """
31
+ # Check if the post is valid
32
+ if (post.get("marked_as_ads") or
33
+ "is_pinned" in post or
34
+ "copy_history" in post or
35
+ len(post.get("attachments", [])) != 1 or
36
+ post["attachments"][0]["type"] != "photo"):
37
+ return None
38
+
39
+ post_id = post["id"]
40
+ text = post["text"].strip()
41
+
42
+ # Get the largest available photo
43
+ photo_sizes = post["attachments"][0]["photo"]["sizes"]
44
+ largest_photo = max(
45
+ photo_sizes,
46
+ key=lambda x: x["width"] * x["height"])
47
+ image_url = largest_photo["url"]
48
+
49
+ return {
50
+ "id": post_id,
51
+ "text": text,
52
+ "image_url": image_url
53
+ }
54
+
55
+ def get_memes(self, public_id: str) -> Dict[str, Any]:
56
+ """
57
+ Retrieve and process all meme posts from a specified public page.
58
+
59
+ Args:
60
+ public_id (str): ID or short name of the public page.
61
+
62
+ Returns:
63
+ Dict[str, Any]: A dictionary containing the public's name and processed meme posts.
64
+ """
65
+ memes = []
66
+
67
+ # Determine whether to use domain or owner_id
68
+ if public_id.isdigit() or (public_id.startswith("-")
69
+ and public_id[1:].isdigit()):
70
+ params: Dict[str, Any] = {"owner_id": int(public_id)}
71
+ else:
72
+ params: Dict[str, Any] = {"domain": public_id}
73
+
74
+ # Fetch public's name
75
+ group_info = self.vk.groups.getById(group_id=public_id)[0]
76
+ group_name = group_info['name']
77
+
78
+ # Process posts
79
+ offset = 0
80
+ while True:
81
+ # Fetch 100 posts at a time
82
+ params["count"] = 100
83
+ params["offset"] = offset
84
+ response = self.vk.wall.get(**params)
85
+
86
+ posts = response["items"]
87
+
88
+ for post in posts:
89
+ processed_post = self._process_post(post)
90
+ if processed_post:
91
+ memes.append(processed_post)
92
+
93
+ # Check if we've reached the end of posts
94
+ if len(posts) < 100:
95
+ break
96
+
97
+ offset += 100  # advance to the next page of 100 posts
98
+
99
+ return {
100
+ "name": group_name,
101
+ "posts": memes
102
+ }
103
+
104
+ def download_image(
105
+ self,
106
+ image_url: str,
107
+ folder_path: str) -> Optional[str]:
108
+ """
109
+ Download an image from the given URL and save it to the specified folder.
110
+
111
+ Args:
112
+ image_url (str): The URL of the image to download.
113
+ folder_path (str): The path to the folder where the image should be saved.
114
+
115
+ Returns:
116
+ Optional[str]: The path to the saved image file, or None if the download failed.
117
+ """
118
+ try:
119
+ # Create the folder if it doesn't exist
120
+ os.makedirs(folder_path, exist_ok=True)
121
+
122
+ filename = os.path.basename(urlparse(image_url).path)
123
+ if not os.path.splitext(filename)[1]:
124
+ return None
125
+
126
+ image_path = os.path.join(folder_path, filename)
127
+
128
+ response = requests.get(image_url, stream=True)
129
+ response.raise_for_status() # Raise an exception for bad status codes
130
+
131
+ with open(image_path, 'wb') as file:
132
+ for chunk in response.iter_content(chunk_size=8192):
133
+ file.write(chunk)
134
+
135
+ return filename
136
+
137
+ except Exception as e:
138
+ print(f"Error downloading image from {image_url}: {str(e)}")
139
+ return None
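
A minimal sketch of the parser in isolation; it requires a real VK API token, and "textmeme" is the public listed in config.yaml:

    from src.parsing.vk_meme_parser import VKMemeParser

    parser = VKMemeParser("YOUR_TOKEN_HERE")   # placeholder token, as in config.yaml
    data = parser.get_memes("textmeme")
    print(data["name"], len(data["posts"]))    # public name and number of usable posts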
src/preprocessing/__pycache__/mystem_tokenizer.cpython-311.pyc ADDED
Binary file (2.47 kB).
 
src/preprocessing/mystem_tokenizer.py ADDED
@@ -0,0 +1,46 @@
1
+ from typing import List, Union
2
+
3
+ import nltk
4
+ from pymystem3 import Mystem
5
+
6
+
7
+ class MystemTokenizer:
8
+ def __init__(self, stopwords: Union[List[str], str] = "ru"):
9
+ """
10
+ Initialize the MystemTokenizer.
11
+
12
+ Args:
13
+ stopwords (Union[List[str], str]): Either a list of stopwords or "ru" for Russian stopwords.
14
+ """
15
+ if stopwords == "ru":
16
+ try:
17
+ self.stopwords = nltk.corpus.stopwords.words("russian")
18
+ except LookupError:
19
+ # Download stopwords if not available
20
+ nltk.download("stopwords")
21
+ self.stopwords = nltk.corpus.stopwords.words("russian")
22
+ else:
23
+ self.stopwords = stopwords
24
+
25
+ self.mystem = Mystem()
26
+
27
+ def tokenize(self, text: str) -> List[str]:
28
+ """
29
+ Tokenize and lemmatize the input text, removing stopwords.
30
+
31
+ Args:
32
+ text (str): The input text to tokenize.
33
+
34
+ Returns:
35
+ List[str]: A list of lemmatized tokens.
36
+ """
37
+ # Lemmatize and tokenize using Mystem
38
+ lemmas = self.mystem.lemmatize(text.lower())
39
+
40
+ # Filter out non-letter tokens and stopwords
41
+ tokens = [
42
+ lemma for lemma in lemmas
43
+ if lemma.isalpha() and lemma not in self.stopwords
44
+ ]
45
+
46
+ return tokens
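
A quick sketch of the tokenizer; pymystem3 downloads the Mystem binary on first use, and NLTK stopwords are fetched automatically if missing:

    from src.preprocessing.mystem_tokenizer import MystemTokenizer

    tokenizer = MystemTokenizer()
    print(tokenizer.tokenize("Коты сидят на окнах"))  # e.g. ['кот', 'сидеть', 'окно']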
src/search/bm25_search.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import pickle
3
+ import time
4
+ from typing import List, Dict, Any, Callable
5
+
6
+ import numpy as np
7
+ from sqlalchemy.orm import Session
8
+ from rank_bm25 import BM25Okapi
9
+
10
+ from src.db import crud
11
+
12
+
13
+ class BM25Search:
14
+ def __init__(
15
+ self,
16
+ db: Session,
17
+ index_folder: str,
18
+ tokenizer: Callable[[str], List[str]]
19
+ ):
20
+ """
21
+ Initialize the BM25Search.
22
+
23
+ Args:
24
+ db (Session): The database session.
25
+ index_folder (str): The folder containing the BM25 index.
26
+ tokenizer (Callable[[str], List[str]]): A function to tokenize the text.
27
+ """
28
+ self.db = db
29
+ self.tokenizer = tokenizer
30
+ self.bm25 = self._load_index(index_folder)
31
+
32
+ def _load_index(self, index_folder: str) -> BM25Okapi:
33
+ """
34
+ Load the BM25 index from a file.
35
+
36
+ Args:
37
+ index_folder (str): The folder containing the BM25 index.
38
+
39
+ Returns:
40
+ BM25Okapi: The loaded BM25 index.
41
+ """
42
+ with open(os.path.join(index_folder, 'bm25_index.pkl'), 'rb') as f:
43
+ return pickle.load(f)
44
+
45
+ def search(self, query: str, n: int = 3) -> Dict[str, Any]:
46
+ """
47
+ Perform a search using BM25.
48
+
49
+ Args:
50
+ query (str): The search query.
51
+ n (int, optional): The number of results to return. Defaults to 3.
52
+
53
+ Returns:
54
+ Dict[str, Any]: A dictionary containing search results and search time.
55
+ """
56
+ start_time = time.time()
57
+
58
+ # Tokenize the query
59
+ query_tokens = self.tokenizer(query)
60
+
61
+ # Retrieve scores for all documents
62
+ scores = self.bm25.get_scores(query_tokens)
63
+
64
+ # Get top n document indices
65
+ top_n_indices = np.argsort(scores)[-n:][::-1]
66
+ top_n_scores = scores[top_n_indices]
67
+
68
+ # Adjust indices to match database IDs (assuming IDs start from 1)
69
+ db_ids = top_n_indices + 1
70
+
71
+ # Retrieve memes from the database
72
+ memes = crud.get_memes_by_ids(self.db, db_ids.tolist())
73
+
74
+ # Format the results
75
+ results = [
76
+ {
77
+ "id": meme.id,
78
+ "public_id": meme.public_id,
79
+ "text": meme.text,
80
+ "image_url": meme.image_url,
81
+ "score": top_n_scores[db_ids.tolist().index(meme.id)]
82
+ }
83
+ for meme in memes
84
+ ]
85
+
86
+ return {
87
+ "results": results,
88
+ "search_time": time.time() - start_time
89
+ }
src/search/semantic_search.py ADDED
@@ -0,0 +1,85 @@
1
+ import time
2
+ from typing import List, Dict, Any, Optional
3
+
4
+ import numpy as np
5
+ from sqlalchemy.orm import Session
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ from src.db import crud
9
+
10
+
11
+ class SemanticSearch:
12
+ def __init__(
13
+ self,
14
+ db: Session,
15
+ model: str,
16
+ embeddings_file: str,
17
+ prefix: Optional[str] = None
18
+ ):
19
+ """
20
+ Initialize the SemanticSearch.
21
+
22
+ Args:
23
+ db (Session): The database session.
24
+ model (str): The name or path of the SentenceTransformer model to use.
25
+ embeddings_file (str): Path to the file containing pre-computed embeddings.
26
+ prefix (Optional[str], optional): A prefix to add to each query. Defaults to None.
27
+ """
28
+ self.db = db
29
+ self.model = SentenceTransformer(model)
30
+ self.prefix = prefix
31
+ self.embeddings = np.load(embeddings_file)
32
+
33
+ def search(self, query: str, n: int = 3) -> Dict[str, Any]:
34
+ """
35
+ Perform a semantic search.
36
+
37
+ Args:
38
+ query (str): The search query.
39
+ n (int, optional): The number of results to return. Defaults to 3.
40
+
41
+ Returns:
42
+ Dict[str, Any]: A dictionary containing search results and search time.
43
+ """
44
+ start_time = time.time()
45
+
46
+ # Prepare query with prefix if provided
47
+ query_text = f"{self.prefix}{query}" if self.prefix else query
48
+
49
+ # Encode the query
50
+ query_embedding = self.model.encode(
51
+ [query_text],
52
+ convert_to_numpy=True,
53
+ normalize_embeddings=True
54
+ )[0]
55
+
56
+ # Compute similarity scores
57
+ scores = np.dot(self.embeddings, query_embedding)
58
+
59
+ # Get top n results
60
+ top_n_indices = np.argsort(scores)[-n:][::-1]
61
+ top_n_scores = scores[top_n_indices]
62
+
63
+ # Adjust indices to match database IDs
64
+ db_ids = top_n_indices + 1
65
+
66
+ # Retrieve memes from the database
67
+ memes = crud.get_memes_by_ids(self.db, db_ids.tolist())
68
+
69
+ # Format the results
70
+ results = [
71
+ {
72
+ "id": meme.id,
73
+ "public_id": meme.public_id,
74
+ "text": meme.text,
75
+ "image_url": meme.image_url,
76
+ "local_image_path": meme.local_image_path,
77
+ "score": top_n_scores[db_ids.tolist().index(meme.id)]
78
+ }
79
+ for meme in memes
80
+ ]
81
+
82
+ return {
83
+ "results": results,
84
+ "search_time": time.time() - start_time
85
+ }
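
Finally, a sketch of querying the semantic index directly, outside the Gradio app, run from the repo root against the database and embeddings added in this commit (the query string is just an example):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from src.search.semantic_search import SemanticSearch

    db = sessionmaker(bind=create_engine("sqlite:///./meme_search.db"))()
    search = SemanticSearch(db,
                            model="intfloat/multilingual-e5-small",
                            embeddings_file="indexes/semantic/embeddings.npy",
                            prefix="query: ")
    for hit in search.search("мем про понедельник", n=3)["results"]:
        print(round(float(hit["score"]), 3), hit["text"][:60])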