sradc committed on
Commit 1801c3b
0 Parent(s):

initial commit
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
+ data
+
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
+ [browser]
+ gatherUsageStats = false
+
+ [theme]
+ base="dark"
README.md ADDED
@@ -0,0 +1,31 @@
+ ---
+ title: Visual Content Search Over Videos
+ emoji: 🐢
+ colorFrom: yellow
+ colorTo: green
+ sdk: streamlit
+ sdk_version: 1.19.0
+ app_file: video_semantic_search/app.py
+ pinned: false
+ ---
+
+ # semvideo-hackathon-230521
+
+ ## Project Description
+
+ This project lets you search YouTube videos using a text string. The search is done over the actual video frames,
+ rather than any associated text. The search results are displayed as a list of videos, with the most relevant video
+ shown first. The user can then click on any of the videos to play it.
+
+ ## Quick Start
+
+ Run the following commands to get started:
+
+ ```bash
+ git clone https://github.com/sradc/semvideo-hackathon-230521.git
+ cd semvideo-hackathon-230521
+ poetry install
+ PYTHONPATH=. poetry run streamlit run video_semantic_search/app.py
+ ```
+
+ If you do not have `poetry` installed, refer to the [poetry documentation](https://python-poetry.org/docs/#installation).
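
How the pieces fit together: the pipeline embeds extracted video frames with CLIP, the app embeds the text query with the same model, and search is a nearest-neighbour lookup in the shared embedding space. Below is a minimal sketch of querying the prebuilt index programmatically; it assumes the pipeline has already produced `data/dataset.parquet` and that you run from the repo root with `PYTHONPATH=.` (the query string is illustrative):

```python
# Minimal sketch, assuming data/dataset.parquet exists (i.e. run_pipeline.sh has run).
# Importing the app module builds the module-level SEARCHER from that parquet file.
from video_semantic_search.app import SEARCHER

# Each result carries video_id, frame_idx, timestamp and a similarity score.
for result in SEARCHER.search("a dog chasing a ball"):
    print(result.video_id, result.timestamp, result.score)
```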
_dev/clip.ipynb ADDED
@@ -0,0 +1,146 @@
+ {
+   "cells": [
+     {
+       "cell_type": "code",
+       "execution_count": 33,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "from typing import List\n",
+         "import requests\n",
+         "from PIL import Image\n",
+         "from transformers import CLIPModel, CLIPProcessor, CLIPFeatureExtractor\n",
+         "import torch"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 41,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
+         "image = Image.open(requests.get(url, stream=True).raw)"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "class ClipWrapper:\n",
+         "    def __init__(self):\n",
+         "        self.model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+         "        self.processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+         "\n",
+         "    def images2vec(self, images: List[Image.Image]) -> torch.Tensor:\n",
+         "        inputs = self.processor(images=images, return_tensors=\"pt\")\n",
+         "        with torch.no_grad():\n",
+         "            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}\n",
+         "            image_embeds = self.model.vision_model(**model_inputs)\n",
+         "            clip_vectors = self.model.visual_projection(image_embeds[1])\n",
+         "        return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)\n",
+         "\n",
+         "    def texts2vec(self, texts: List[str]) -> torch.Tensor:\n",
+         "        inputs = self.processor(text=texts, return_tensors=\"pt\", padding=True)\n",
+         "        with torch.no_grad():\n",
+         "            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}\n",
+         "            text_embeds = self.model.text_model(**model_inputs)\n",
+         "            text_vectors = self.model.text_projection(text_embeds[1])\n",
+         "        return text_vectors / text_vectors.norm(dim=-1, keepdim=True)"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 42,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+         "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 65,
+       "metadata": {},
+       "outputs": [
+         {
+           "data": {
+             "text/plain": [
+               "torch.Size([2, 512])"
+             ]
+           },
+           "execution_count": 65,
+           "metadata": {},
+           "output_type": "execute_result"
+         }
+       ],
+       "source": [
+         "def images2vec(images: List[Image.Image]) -> torch.Tensor:\n",
+         "    inputs = processor(images=images, return_tensors=\"pt\")\n",
+         "    with torch.no_grad():\n",
+         "        model_inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
+         "        image_embeds = model.vision_model(**model_inputs)\n",
+         "        clip_vectors = model.visual_projection(image_embeds[1])\n",
+         "    return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)\n",
+         "\n",
+         "\n",
+         "result = images2vec([image, image])\n",
+         "result.shape"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 70,
+       "metadata": {},
+       "outputs": [
+         {
+           "data": {
+             "text/plain": [
+               "torch.Size([2, 512])"
+             ]
+           },
+           "execution_count": 70,
+           "metadata": {},
+           "output_type": "execute_result"
+         }
+       ],
+       "source": [
+         "def texts2vec(texts: List[str]) -> torch.Tensor:\n",
+         "    inputs = processor(text=texts, return_tensors=\"pt\", padding=True)\n",
+         "    with torch.no_grad():\n",
+         "        model_inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
+         "        text_embeds = model.text_model(**model_inputs)\n",
+         "        text_vectors = model.text_projection(text_embeds[1])\n",
+         "    return text_vectors / text_vectors.norm(dim=-1, keepdim=True)\n",
+         "\n",
+         "\n",
+         "texts2vec([\"a photo of a cat\", \"a photo of a dog\"]).shape"
+       ]
+     }
+   ],
+   "metadata": {
+     "kernelspec": {
+       "display_name": "semvideo-hackathon-230523",
+       "language": "python",
+       "name": "python3"
+     },
+     "language_info": {
+       "codemirror_mode": {
+         "name": "ipython",
+         "version": 3
+       },
+       "file_extension": ".py",
+       "mimetype": "text/x-python",
+       "name": "python",
+       "nbconvert_exporter": "python",
+       "pygments_lexer": "ipython3",
+       "version": "3.9.16"
+     },
+     "orig_nbformat": 4
+   },
+   "nbformat": 4,
+   "nbformat_minor": 2
+ }
_dev/download_videos.ipynb ADDED
@@ -0,0 +1,80 @@
+ {
+   "cells": [
+     {
+       "cell_type": "code",
+       "execution_count": 22,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "import subprocess\n",
+         "from pathlib import Path\n",
+         "import re\n",
+         "\n",
+         "from tqdm import tqdm\n",
+         "\n",
+         "import pipeline.videos as videos"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 23,
+       "metadata": {},
+       "outputs": [
+         {
+           "name": "stderr",
+           "output_type": "stream",
+           "text": [
+             "100%|██████████| 1/1 [00:02<00:00, 2.37s/it]\n"
+           ]
+         }
+       ],
+       "source": [
+         "VIDEO_DIR = Path(\"videos\")\n",
+         "VIDEO_DIR.mkdir(exist_ok=True)\n",
+         "(VIDEO_DIR / \".gitignore\").write_text(\"**\")\n",
+         "\n",
+         "video_urls = [\"https://www.youtube.com/watch?v=frYIj2FGmMA&foo=bar\"]\n",
+         "\n",
+         "\n",
+         "def get_id(url: str) -> str:\n",
+         "    return re.search(r\"(?<=v=)[^&]+\", url).group(0)\n",
+         "\n",
+         "\n",
+         "for video_url in tqdm(video_urls):\n",
+         "    video_id = get_id(video_url)\n",
+         "    video_path = VIDEO_DIR / f\"{video_id}.mp4\"\n",
+         "    if video_path.exists():\n",
+         "        print(f\"Skipping {video_path} because it already exists\")\n",
+         "        continue\n",
+         "    subprocess.run([\"yt-dlp\", \"--quiet\", \"-f\", \"133\", \"-o\", str(video_path), video_url])\n",
+         "\n",
+         "# get_id(video_urls[0])\n",
+         "# # !yt-dlp -f 133 -o \"buster.mp4\" {video_url}\n",
+         "# def download_video(video_url: str) -> None:\n",
+         "#     subprocess.run(['yt-dlp', '-f', '133', '-o', 'buster.mp4', video_url])"
+       ]
+     }
+   ],
+   "metadata": {
+     "kernelspec": {
+       "display_name": "semvideo-hackathon-230523",
+       "language": "python",
+       "name": "python3"
+     },
+     "language_info": {
+       "codemirror_mode": {
+         "name": "ipython",
+         "version": 3
+       },
+       "file_extension": ".py",
+       "mimetype": "text/x-python",
+       "name": "python",
+       "nbconvert_exporter": "python",
+       "pygments_lexer": "ipython3",
+       "version": "3.9.16"
+     },
+     "orig_nbformat": 4
+   },
+   "nbformat": 4,
+   "nbformat_minor": 2
+ }
_dev/process_video.ipynb ADDED
@@ -0,0 +1,149 @@
+ {
+   "cells": [
+     {
+       "cell_type": "code",
+       "execution_count": 1,
+       "metadata": {},
+       "outputs": [
+         {
+           "name": "stderr",
+           "output_type": "stream",
+           "text": [
+             "/Users/sidneyradcliffe/miniforge3/envs/semvideo-hackathon-230523/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+             "  from .autonotebook import tqdm as notebook_tqdm\n"
+           ]
+         }
+       ],
+       "source": [
+         "from tqdm import tqdm\n",
+         "\n",
+         "import pandas as pd\n",
+         "import cv2\n",
+         "from PIL import Image\n",
+         "import numpy as np\n",
+         "\n",
+         "from pipeline.clip_wrapper import ClipWrapper, MODEL_DIM\n",
+         "from pipeline.download_videos import VIDEO_DIR, REPO_ROOT, DATA_DIR\n",
+         "\n",
+         "FRAME_EXTRACT_RATE_SECONDS = 5  # Extract a frame every 5 seconds\n",
+         "IMAGES_DIR = DATA_DIR / \"images\"\n",
+         "\n",
+         "DATAFRAME_PATH = DATA_DIR / \"dataset.parquet\""
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 2,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "clip_wrapper = ClipWrapper()"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 3,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "def get_clip_vectors(video_path):\n",
+         "    cap = cv2.VideoCapture(str(video_path))\n",
+         "    num_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
+         "    fps = int(cap.get(cv2.CAP_PROP_FPS))\n",
+         "    extract_every_n_frames = FRAME_EXTRACT_RATE_SECONDS * fps\n",
+         "    for frame_idx in tqdm(range(num_video_frames), desc=\"Running CLIP on video\"):\n",
+         "        ret, frame = cap.read()\n",
+         "        if frame_idx % extract_every_n_frames != 0:\n",
+         "            continue\n",
+         "        image = Image.fromarray(frame[..., ::-1])\n",
+         "        clip_vector = clip_wrapper.images2vec([image]).squeeze().numpy()\n",
+         "        timestamp_secs = frame_idx / fps\n",
+         "        yield clip_vector, image, timestamp_secs, frame_idx\n",
+         "    cap.release()"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": 4,
+       "metadata": {},
+       "outputs": [
+         {
+           "name": "stderr",
+           "output_type": "stream",
+           "text": [
+             "Running CLIP on video: 100%|██████████| 7465/7465 [00:04<00:00, 1759.86it/s]\n",
+             "Running CLIP on video: 100%|██████████| 6056/6056 [00:03<00:00, 1728.62it/s]\n",
+             "Running CLIP on video: 100%|██████████| 5234/5234 [00:03<00:00, 1648.12it/s]\n",
+             "Running CLIP on video: 100%|██████████| 3551/3551 [00:01<00:00, 1806.30it/s]\n",
+             "Running CLIP on video: 100%|██████████| 5904/5904 [00:03<00:00, 1655.01it/s]\n",
+             "Processing videos: 100%|██████████| 5/5 [00:16<00:00, 3.30s/it]"
+           ]
+         },
+         {
+           "name": "stdout",
+           "output_type": "stream",
+           "text": [
+             "Saving data to /Users/sidneyradcliffe/repos/semvideo-hackathon-230523/data/dataset.parquet\n"
+           ]
+         },
+         {
+           "name": "stderr",
+           "output_type": "stream",
+           "text": [
+             "\n"
+           ]
+         }
+       ],
+       "source": [
+         "results = []\n",
+         "for i, video_path in enumerate(\n",
+         "    tqdm(list(VIDEO_DIR.glob(\"*.mp4\")), desc=\"Processing videos\")\n",
+         "):\n",
+         "    video_id = video_path.stem\n",
+         "    extracted_images_dir = IMAGES_DIR / video_id\n",
+         "    extracted_images_dir.mkdir(exist_ok=True, parents=True)\n",
+         "    for clip_vector, image, timestamp_secs, frame_idx in get_clip_vectors(video_path):\n",
+         "        image_path = extracted_images_dir / f\"{frame_idx}.jpg\"\n",
+         "        image.save(image_path)\n",
+         "        results.append(\n",
+         "            [\n",
+         "                video_id,\n",
+         "                frame_idx,\n",
+         "                timestamp_secs,\n",
+         "                str(image_path.relative_to(REPO_ROOT)),\n",
+         "                *clip_vector,\n",
+         "            ]\n",
+         "        )\n",
+         "df = pd.DataFrame(\n",
+         "    results,\n",
+         "    columns=[\"video_id\", \"frame_idx\", \"timestamp\", \"image_path\"]\n",
+         "    + [f\"dim_{i}\" for i in range(MODEL_DIM)],\n",
+         ")\n",
+         "print(f\"Saving data to {DATAFRAME_PATH}\")\n",
+         "df.to_parquet(DATAFRAME_PATH, index=False)"
+       ]
+     }
+   ],
+   "metadata": {
+     "kernelspec": {
+       "display_name": "semvideo-hackathon-230523",
+       "language": "python",
+       "name": "python3"
+     },
+     "language_info": {
+       "codemirror_mode": {
+         "name": "ipython",
+         "version": 3
+       },
+       "file_extension": ".py",
+       "mimetype": "text/x-python",
+       "name": "python",
+       "nbconvert_exporter": "python",
+       "pygments_lexer": "ipython3",
+       "version": "3.9.16"
+     },
+     "orig_nbformat": 4
+   },
+   "nbformat": 4,
+   "nbformat_minor": 2
+ }
_dev/run_search_over_videos.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
activate ADDED
@@ -0,0 +1 @@
+ conda activate semvideo-hackathon-230523
example.py ADDED
@@ -0,0 +1,23 @@
+ import requests
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ inputs = processor(
+     text=["a photo of a cat", "a photo of a dog"],
+     images=image,
+     return_tensors="pt",
+     padding=True,
+ )
+
+ outputs = model(**inputs)
+ logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+ probs = logits_per_image.softmax(
+     dim=1
+ )  # we can take the softmax to get the label probabilities
+ print(probs)
pipeline/clip_wrapper.py ADDED
@@ -0,0 +1,29 @@
+ from typing import List
+
+ import torch
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ MODEL_DIM = 512
+
+
+ class ClipWrapper:
+     def __init__(self):
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+     def images2vec(self, images: List[Image.Image]) -> torch.Tensor:
+         inputs = self.processor(images=images, return_tensors="pt")
+         with torch.no_grad():
+             model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+             image_embeds = self.model.vision_model(**model_inputs)
+             clip_vectors = self.model.visual_projection(image_embeds[1])
+         return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)
+
+     def texts2vec(self, texts: List[str]) -> torch.Tensor:
+         inputs = self.processor(text=texts, return_tensors="pt", padding=True)
+         with torch.no_grad():
+             model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+             text_embeds = self.model.text_model(**model_inputs)
+             text_vectors = self.model.text_projection(text_embeds[1])
+         return text_vectors / text_vectors.norm(dim=-1, keepdim=True)
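
Both methods return unit-norm vectors, so text-image cosine similarity is just a matrix product. A usage sketch of `ClipWrapper` (the image URL is the COCO example used elsewhere in this commit):

```python
import requests
from PIL import Image

from pipeline.clip_wrapper import ClipWrapper

wrapper = ClipWrapper()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_vecs = wrapper.images2vec([image])  # shape (1, MODEL_DIM)
text_vecs = wrapper.texts2vec(["a photo of a cat", "a photo of a dog"])  # shape (2, MODEL_DIM)

# Vectors are L2-normalised, so the matrix product is cosine similarity.
print(image_vecs @ text_vecs.T)
```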
pipeline/download_videos.py ADDED
@@ -0,0 +1,38 @@
+ import re
+ import subprocess
+ from pathlib import Path
+ from typing import List
+
+ from tqdm import tqdm
+
+ REPO_ROOT = Path(__file__).parents[1].resolve()
+ DATA_DIR = REPO_ROOT / "data"
+ VIDEO_DIR = DATA_DIR / "videos"
+ VIDEO_ID_FOLDER = DATA_DIR / "ids"
+
+
+ def get_id(url: str) -> str:
+     return re.search(r"(?<=v=)[^&]+", url).group(0)
+
+
+ def download_videos(video_ids: List[str]) -> None:
+     VIDEO_DIR.mkdir(exist_ok=True, parents=True)
+     for video_id in tqdm(video_ids):
+         video_url = f"https://www.youtube.com/watch?v={video_id}"
+         video_path = VIDEO_DIR / f"{video_id}.mp4"
+         if video_path.exists():
+             print(f"Skipping {video_path} because it already exists")
+             continue
+         subprocess.run(
+             ["yt-dlp", "--quiet", "-f", "135", "-o", str(video_path), video_url]
+         )
+
+
+ if __name__ == "__main__":
+     print("Downloading videos...")
+     ids = set()
+     for file in VIDEO_ID_FOLDER.glob("*.txt"):
+         ids.update(
+             [x for x in file.read_text().strip().splitlines(keepends=False) if x]
+         )
+     download_videos(ids)
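
A sketch of driving this module directly (the video ID is illustrative; `-f 135` selects what is, to the best of my knowledge, a 480p MP4 video-only stream in yt-dlp's format table):

```python
from pipeline.download_videos import VIDEO_DIR, download_videos, get_id

# get_id pulls the value of the `v=` query parameter out of a watch URL.
assert get_id("https://www.youtube.com/watch?v=frYIj2FGmMA&foo=bar") == "frYIj2FGmMA"

# Downloads to data/videos/<id>.mp4, skipping any file that already exists.
download_videos(["frYIj2FGmMA"])
print(sorted(VIDEO_DIR.glob("*.mp4")))
```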
pipeline/get_video_ids.py ADDED
@@ -0,0 +1,79 @@
+ import hashlib
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Final, Optional
+
+ import youtube_dl
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ PLAYLIST_URLS = [
+     "https://www.youtube.com/playlist?list=PL6Lt9p1lIRZ311J9ZHuzkR5A3xesae2pk",  # 570, Alternative rock of the 2000s (2000-2009)
+     "https://www.youtube.com/playlist?list=PLMC9KNkIncKtGvr2kFRuXBVmBev6cAJ2u",  # 250, Best Pop Music Videos - Top Pop Hits Playlist
+     "https://www.youtube.com/playlist?list=PLmXxqSJJq-yXrCPGIT2gn8b34JjOrl4Xf",  # 184, 80s Music Hits | Best 80s Music Playlist
+     "https://www.youtube.com/playlist?list=PL7DA3D097D6FDBC02",  # 150, 90's Hits - Greatest 1990's Music Hits (Best 90’s Songs Playlist)
+     "https://www.youtube.com/playlist?list=PLeDakahyfrO-4kuBioL5ZAoy4j6aCnzWy",  # 100, Best Music Videos of All Time
+     "https://www.youtube.com/playlist?list=PLMC9KNkIncKtPzgY-5rmhvj7fax8fdxoj",  # 200, Pop Music Playlist - Timeless Pop Songs (Updated Weekly 2023)
+     "https://www.youtube.com/playlist?list=PLkqz3S84Tw-RfPS9HHi3MRmrinOBKxIr8",  # 82, Top POP Hits 2022 – Biggest Pop Music Videos - Vevo
+     "https://www.youtube.com/playlist?list=PLyORnIW1xT6wqvszJbCdLdSjylYMf3sNZ",  # 100, Top 100 Music Videos 2023 - Best Music Videos 2023
+     "https://www.youtube.com/playlist?list=PL1Mmsa-U48mea1oIN-Eus78giJANx4D9W",  # 119, 90s Music Videos
+     "https://www.youtube.com/playlist?list=PLurPBtLcqJqcg3r-HOhR3LZ0aDxpI15Fa",  # 100, 100 Best Music Videos Of The Decade: 2010 - 2019
+     "https://www.youtube.com/playlist?list=PLCQCtoOJpI_A5oktQImEdDBJ50BqHXujj",  # 495, MTV Classic 2000's music videos (US Version)
+ ]
+ URL_FILE: Final[Optional[str]] = os.environ.get("URL_FILE")
+ OUTPUT_DIR: Final[str] = os.environ.get("OUTPUT_DIR", "data/ids")
+
+
+ def get_all_video_ids(channel_url: str) -> list[str]:
+     """Get all video IDs from a YouTube channel or playlist URL.
+
+     Args:
+         channel_url (str): URL of the YouTube channel or playlist.
+
+     Returns:
+         list[str]: List of video IDs.
+
+     Notes:
+         If you want the videos from a channel, make sure to pass the `/videos` endpoint of the channel.
+     """
+     ydl_opts = {
+         "ignoreerrors": True,
+         "extract_flat": "in_playlist",
+         "dump_single_json": True,
+         "quiet": True,
+     }
+
+     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+         playlist_info = ydl.extract_info(channel_url, download=False)
+         video_ids = [video["id"] for video in playlist_info["entries"] if "id" in video]
+
+     return video_ids
+
+
+ def process_youtube_url(url: str):
+     logging.info(f"Processing {url}")
+     ids = get_all_video_ids(url)
+
+     output_dir = Path(OUTPUT_DIR)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     output = "\n".join(ids)
+     output_path = output_dir / f"{hashlib.md5(output.encode()).hexdigest()}.txt"
+     logging.info(f"Writing {len(ids)} video IDs to {output_path}")
+     with output_path.open(mode="w") as f:
+         f.write(output)
+
+
+ def main():
+     logging.info(f"Processing {len(PLAYLIST_URLS)} URLs")
+     for url in PLAYLIST_URLS:
+         process_youtube_url(url)
+
+
+ if __name__ == "__main__":
+     main()
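
A quick sketch of calling the scraper directly, using one of the playlists listed above (when run as a script it instead writes one `<md5>.txt` file of IDs per playlist into `OUTPUT_DIR`, default `data/ids`; this needs network access and a working `youtube_dl`):

```python
from pipeline.get_video_ids import get_all_video_ids

ids = get_all_video_ids(
    "https://www.youtube.com/playlist?list=PLeDakahyfrO-4kuBioL5ZAoy4j6aCnzWy"
)
print(len(ids), ids[:3])  # playlist size and a few sample video IDs
```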
pipeline/process_videos.py ADDED
@@ -0,0 +1,66 @@
+ import cv2
+ import pandas as pd
+ from PIL import Image
+ from tqdm import tqdm
+
+ from pipeline.clip_wrapper import MODEL_DIM, ClipWrapper
+ from pipeline.download_videos import DATA_DIR, REPO_ROOT, VIDEO_DIR
+
+ FRAME_EXTRACT_RATE_SECONDS = 5  # Extract a frame every 5 seconds
+ IMAGES_DIR = DATA_DIR / "images"
+ DATAFRAME_PATH = DATA_DIR / "dataset.parquet"
+
+
+ def process_videos() -> None:
+     "Runs clip on video frames, saves results to a parquet file"
+     clip_wrapper = ClipWrapper()
+     results = []
+     for video_path in tqdm(list(VIDEO_DIR.glob("*.mp4")), desc="Processing videos"):
+         video_id = video_path.stem
+         extracted_images_dir = IMAGES_DIR / video_id
+         extracted_images_dir.mkdir(exist_ok=True, parents=True)
+         complete_file = extracted_images_dir / "complete"
+         if complete_file.exists():
+             continue
+         for clip_vector, image, timestamp_secs, frame_idx in get_clip_vectors(
+             video_path, clip_wrapper
+         ):
+             image_path = extracted_images_dir / f"{frame_idx}.jpg"
+             image.save(image_path)
+             results.append(
+                 [
+                     video_id,
+                     frame_idx,
+                     timestamp_secs,
+                     str(image_path.relative_to(REPO_ROOT)),
+                     *clip_vector,
+                 ]
+             )
+         complete_file.touch()
+     df = pd.DataFrame(
+         results,
+         columns=["video_id", "frame_idx", "timestamp", "image_path"]
+         + [f"dim_{i}" for i in range(MODEL_DIM)],
+     )
+     print(f"Saving data to {DATAFRAME_PATH}")
+     df.to_parquet(DATAFRAME_PATH, index=False)
+
+
+ def get_clip_vectors(video_path, clip_wrapper):
+     cap = cv2.VideoCapture(str(video_path))
+     num_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     extract_every_n_frames = FRAME_EXTRACT_RATE_SECONDS * fps
+     for frame_idx in tqdm(range(num_video_frames), desc="Running CLIP on video"):
+         ret, frame = cap.read()
+         if frame_idx % extract_every_n_frames != 0:
+             continue
+         image = Image.fromarray(frame[..., ::-1])
+         clip_vector = clip_wrapper.images2vec([image]).squeeze().numpy()
+         timestamp_secs = frame_idx / fps
+         yield clip_vector, image, timestamp_secs, frame_idx
+     cap.release()
+
+
+ if __name__ == "__main__":
+     process_videos()
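
The resulting parquet file has one row per extracted frame: four metadata columns followed by the 512 CLIP dimensions (`dim_0` through `dim_511`). A sketch of loading it back and splitting metadata from the embedding matrix, mirroring what the app does at startup:

```python
import numpy as np
import pandas as pd

df = pd.read_parquet("data/dataset.parquet")

dim_columns = df.filter(regex="^dim_").columns
vectors = np.ascontiguousarray(df[dim_columns].to_numpy(np.float32))  # (n_frames, 512)
metadata = df.drop(columns=dim_columns)  # video_id, frame_idx, timestamp, image_path

print(metadata.head(), vectors.shape)
```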
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,39 @@
+ [tool.poetry]
+ name = "video-semantic-search"
+ version = "0.1.0"
+ description = ""
+ authors = ["Ben Tenmann <benji.tenmann@me.com>", "Sidney Radcliffe <sidneyradcliffe@gmail.com>"]
+ license = "MIT"
+ readme = "README.md"
+ packages = [{include = "video_semantic_search"}]
+
+ [tool.poetry.dependencies]
+ python = ">=3.9,<3.9.7 || >3.9.7,<4.0"
+ streamlit = "^1.22.0"
+ pandas = "^2.0.1"
+ pyarrow = "^12.0.0"
+ # faiss-cpu is pinned because of segfaults when interacting with streamlit:
+ # https://github.com/facebookresearch/faiss/issues/2099#issuecomment-961172708
+ # (that issue suggests 1.6.5, but 1.7.4 works for us, so we pin that instead)
+ faiss-cpu = "==1.7.4"
+ transformers = "^4.29.2"
+ torch = "^2.0.1"
+ torchvision = "^0.15.2"
+ urllib3 = "1.26.15"
+ yt-dlp = "^2023.3.4"
+ tqdm = "^4.65.0"
+ opencv-python = "^4.7.0.72"
+ youtube-dl = "^2021.12.17"
+
+ [tool.poetry.group.dev.dependencies]
+ notebook = "^6.5.4"
+ black = {extras = ["jupyter"], version = "^23.3.0"}
+ isort = "^5.12.0"
+ pytest = "^7.3.1"
+ jupyterlab = "^4.0.0"
+ nbconvert = "^7.4.0"
+ jupyter-contrib-nbextensions = "^0.7.0"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
run_pipeline.sh ADDED
@@ -0,0 +1,5 @@
+ #!/usr/bin/env bash
+ set -e
+
+ poetry run python pipeline/download_videos.py
+ poetry run python pipeline/process_videos.py
tests/pipeline/test_clip_wrapper.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+
+ from pipeline.clip_wrapper import ClipWrapper
+
+
+ def test_ClipWrapper():
+     clip_wrapper = ClipWrapper()
+
+     images = [torch.rand(3, 224, 224) for _ in range(2)]
+     assert clip_wrapper.images2vec(images).shape[-1] == 512
+
+     texts = ["a photo of a cat", "a photo of a dog"]
+     assert clip_wrapper.texts2vec(texts).shape[-1] == 512
tests/pipeline/test_download_videos.py ADDED
@@ -0,0 +1,10 @@
+ from pipeline.download_videos import get_id
+
+
+ def test_get_id():
+     url1 = "https://www.youtube.com/watch?v=frYIj2FGmMA&foo=bar"
+     url2 = "https://www.youtube.com/watch?v=abcdefg"
+     url3 = "https://www.youtube.com/watch?foo=bar&v=xyz123"
+     assert get_id(url1) == "frYIj2FGmMA"
+     assert get_id(url2) == "abcdefg"
+     assert get_id(url3) == "xyz123"
video_semantic_search/__init__.py ADDED
File without changes
video_semantic_search/app.py ADDED
@@ -0,0 +1,123 @@
+ import base64
+ import os
+ from dataclasses import dataclass
+ from typing import Final
+
+ import faiss
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+
+ from pipeline import clip_wrapper
+
+
+ class SemanticSearcher:
+     def __init__(self, dataset: pd.DataFrame):
+         dim_columns = dataset.filter(regex="^dim_").columns
+
+         self.embedder = clip_wrapper.ClipWrapper().texts2vec
+         self.metadata = dataset.drop(columns=dim_columns)
+         self.index = faiss.IndexFlatIP(len(dim_columns))
+         self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))
+
+     def search(self, query: str) -> list["SearchResult"]:
+         v = self.embedder([query]).detach().numpy()
+         D, I = self.index.search(v, 10)
+         return [
+             SearchResult(
+                 video_id=row["video_id"],
+                 frame_idx=row["frame_idx"],
+                 timestamp=row["timestamp"],
+                 score=score,
+             )
+             for score, (_, row) in zip(D[0], self.metadata.iloc[I[0]].iterrows())
+         ]
+
+
+ DATASET_PATH: Final[str] = os.environ.get("DATASET_PATH", "data/dataset.parquet")
+ SEARCHER: Final[SemanticSearcher] = SemanticSearcher(pd.read_parquet(DATASET_PATH))
+
+
+ @dataclass
+ class SearchResult:
+     video_id: str
+     frame_idx: int
+     timestamp: float
+     score: float
+
+
+ def get_video_url(video_id: str, timestamp: float) -> str:
+     return f"https://www.youtube.com/watch?v={video_id}&t={int(timestamp)}"
+
+
+ def display_search_results(results: list[SearchResult]) -> None:
+     col_count = 3  # Number of videos per row
+
+     col_num = 0  # Counter to keep track of the current column
+     row = st.empty()  # Placeholder for the current row
+
+     for i, result in enumerate(results):
+         if col_num == 0:
+             row = st.columns(col_count)  # Create a new row of columns
+
+         with row[col_num]:
+             # Apply CSS styling to the video container
+             st.markdown(
+                 """
+                 <style>
+                 .video-container {
+                     position: relative;
+                     padding-bottom: 56.25%;
+                     padding-top: 30px;
+                     height: 0;
+                     overflow: hidden;
+                 }
+                 .video-container iframe,
+                 .video-container object,
+                 .video-container embed {
+                     position: absolute;
+                     top: 0;
+                     left: 0;
+                     width: 100%;
+                     height: 100%;
+                 }
+                 </style>
+                 """,
+                 unsafe_allow_html=True,
+             )
+
+             # Display the embedded YouTube video
+             # st.video(get_video_url(result.video_id), start_time=int(result.timestamp))
+             # st.image(f"data/images/{result.video_id}/{result.frame_idx}.jpg")
+             with open(
+                 f"data/images/{result.video_id}/{result.frame_idx}.jpg", "rb"
+             ) as f:
+                 image = f.read()
+             encoded = base64.b64encode(image).decode()
+             st.markdown(
+                 f"""
+                 <a href="{get_video_url(result.video_id, result.timestamp)}">
+                     <img src="data:image/jpeg;base64,{encoded}" alt="frame {result.frame_idx}" width="100%">
+                 </a>
+                 """,
+                 unsafe_allow_html=True,
+             )
+
+         col_num += 1
+         if col_num >= col_count:
+             col_num = 0
+
+
+ def main():
+     st.set_page_config(page_title="video-semantic-search", layout="wide")
+     st.header("Video Semantic Search")
+
+     st.text_input("What are you looking for?", key="query")
+
+     query = st.session_state["query"]
+     if query:
+         display_search_results(SEARCHER.search(query))
+
+
+ if __name__ == "__main__":
+     main()
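
A note on the index choice in `SemanticSearcher`: `faiss.IndexFlatIP` performs exact inner-product search, and because `ClipWrapper` returns unit-norm vectors, the inner product equals cosine similarity. A standalone sketch of the same pattern with random unit vectors:

```python
import faiss
import numpy as np

dim = 512
vectors = np.random.rand(100, dim).astype(np.float32)
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)  # unit-norm, like the CLIP vectors

index = faiss.IndexFlatIP(dim)
index.add(vectors)

# Querying with a stored vector returns that vector first, with score ~1.0.
D, I = index.search(vectors[:1], 5)
print(I[0], D[0])
```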