ngxquang commited on
Commit
db24a4e
1 Parent(s): 4f03b5c

Add application file

Browse files
Files changed (14) hide show
  1. .gitattributes +1 -0
  2. .gitignore +165 -0
  3. Dockerfile +34 -0
  4. cache.json +3 -0
  5. candidate.py +13 -0
  6. classes.json +83 -0
  7. frame.py +17 -0
  8. helper.py +49 -0
  9. main.py +85 -0
  10. requirements.txt +35 -0
  11. searcher.py +45 -0
  12. settings.py +2 -0
  13. synsets.txt +0 -0
  14. trie.py +109 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ cache.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .python-version
163
+
164
+ # cache.json
165
+ data/
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Install git (pip VCS requirements) and gsutil (Google Cloud Storage CLI),
# then drop all apt caches/lists to keep the image slim.
# Fix: the original mixed `apt-get install` with `apt clean` and never removed
# /var/lib/apt/lists; use apt-get consistently and clean both locations.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git gsutil && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/*

WORKDIR /code

COPY requirements.txt /code/requirements.txt

# PYTHONDONTWRITEBYTECODE=1: disables the creation of .pyc files (compiled bytecode)
# PYTHONUNBUFFERED=1: disables buffering of the standard output stream
# PYTHONIOENCODING: specifies the encoding used for the standard input, output, and error streams
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONIOENCODING=utf-8

# Install dependencies and pre-fetch the WordNet corpus into a system-wide
# nltk_data directory so it is available to the unprivileged runtime user.
RUN pip install -U pip && \
    pip install --no-cache-dir -r /code/requirements.txt && \
    python -m nltk.downloader -d /usr/share/nltk_data wordnet

# Run as an unprivileged user (uid 1000, as required e.g. by HF Spaces).
RUN useradd -m -u 1000 user

USER user

ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

CMD ["python", "main.py"]
cache.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db9e7ad9f9620753805e23353fcd3f9a972ae098d7678510949b861e598e7b28
3
+ size 75224556
candidate.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from frame import Frame
2
+
3
class Candidate:
    """A scored search hit: a Frame paired with its relevance score."""

    def __init__(self, frame: Frame, score: float):
        self.frame = frame
        self.score = score

    def serialize(self) -> dict:
        """Return a JSON-serializable dict; '.jpg' is appended to the frame name."""
        video, name = self.frame.video, self.frame.frame_name
        return {'video': video, 'frame_name': name + '.jpg', 'score': self.score}
classes.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "_background_",
3
+ "1": "person",
4
+ "2": "bicycle",
5
+ "3": "car",
6
+ "4": "motorcycle",
7
+ "5": "airplane",
8
+ "6": "bus",
9
+ "7": "train",
10
+ "8": "truck",
11
+ "9": "boat",
12
+ "10": "traffic light",
13
+ "11": "fire hydrant",
14
+ "12": "stop sign",
15
+ "13": "parking meter",
16
+ "14": "bench",
17
+ "15": "bird",
18
+ "16": "cat",
19
+ "17": "dog",
20
+ "18": "horse",
21
+ "19": "sheep",
22
+ "20": "cow",
23
+ "21": "elephant",
24
+ "22": "bear",
25
+ "23": "zebra",
26
+ "24": "giraffe",
27
+ "25": "backpack",
28
+ "26": "umbrella",
29
+ "27": "handbag",
30
+ "28": "tie",
31
+ "29": "suitcase",
32
+ "30": "frisbee",
33
+ "31": "skis",
34
+ "32": "snowboard",
35
+ "33": "sports ball",
36
+ "34": "kite",
37
+ "35": "baseball bat",
38
+ "36": "baseball glove",
39
+ "37": "skateboard",
40
+ "38": "surfboard",
41
+ "39": "tennis racket",
42
+ "40": "bottle",
43
+ "41": "wine glass",
44
+ "42": "cup",
45
+ "43": "fork",
46
+ "44": "knife",
47
+ "45": "spoon",
48
+ "46": "bowl",
49
+ "47": "banana",
50
+ "48": "apple",
51
+ "49": "sandwich",
52
+ "50": "orange",
53
+ "51": "broccoli",
54
+ "52": "carrot",
55
+ "53": "hot dog",
56
+ "54": "pizza",
57
+ "55": "donut",
58
+ "56": "cake",
59
+ "57": "chair",
60
+ "58": "couch",
61
+ "59": "potted plant",
62
+ "60": "bed",
63
+ "61": "dining table",
64
+ "62": "toilet",
65
+ "63": "tv",
66
+ "64": "laptop",
67
+ "65": "mouse",
68
+ "66": "remote",
69
+ "67": "keyboard",
70
+ "68": "cell phone",
71
+ "69": "microwave",
72
+ "70": "oven",
73
+ "71": "toaster",
74
+ "72": "sink",
75
+ "73": "refrigerator",
76
+ "74": "book",
77
+ "75": "clock",
78
+ "76": "vase",
79
+ "77": "scissors",
80
+ "78": "teddy bear",
81
+ "79": "hair drier",
82
+ "80": "toothbrush"
83
+ }
frame.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Frame:
2
+ def __init__(self, video: str|None = None, frame_name: str|None = None, id: str|None = None) -> None:
3
+ if id == None:
4
+ self.video = video
5
+ self.frame_name = frame_name
6
+ self.id = video + '/' + frame_name
7
+ else:
8
+ self.id = id
9
+ self.video, self.frame_name = self.id.split('/')
10
+
11
+ def serialize(self) -> dict:
12
+ return {
13
+ 'video': self.video,
14
+ 'frame_name': self.frame_name,
15
+ 'id': self.id,
16
+ }
17
+
helper.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nltk.corpus import wordnet as wn
2
+ from fastapi.responses import JSONResponse
3
+ from google.cloud import storage
4
+
5
# Detector labels remapped to a related term before the WordNet lookup in
# get_hypernym_path (presumably because the literal label has no useful noun
# synset — TODO confirm against the WordNet coverage of these labels).
OBJECT_MAP = {
    "stop_sign": "traffic_control",
    "sports_ball": "ball",
    "wine_glass": "glass",
    "potted_plant": "plant",
}

# Score added each time a frame matches an additional queried object
# (applied by Searcher.search when merging per-object candidate scores).
MULTI_OBJECT_BONUS = 1.0
13
+
14
def get_hypernym_path(object):
    """Return the WordNet hypernym path (root -> leaf) for an object label.

    The label is lower-cased and space->underscore normalized, optionally
    remapped through OBJECT_MAP, then looked up as a noun.  When WordNet has
    no noun synset for it, falls back to a one-element path containing the
    normalized label itself.
    """
    object = object.lower().replace(' ', '_')
    object = OBJECT_MAP.get(object, object)  # dict.get instead of membership test + index
    synsets = wn.synsets(object, 'n')
    if not synsets:
        return [object]
    # Synset names look like "dog.n.01"; strip the ".<pos>.<nn>" suffix with
    # rsplit instead of the brittle fixed-width slice `[:-5]`, which silently
    # corrupts lemma names whose suffix is not exactly five characters.
    return [s.name().rsplit('.', 2)[0] for s in synsets[0].hypernym_paths()[0]]
24
+
25
def parse_query(query: str) -> list[dict]:
    """Parse ``"obj1: n1, obj2, ..."`` into ``[{'object': ..., 'amount': ...}]``.

    An entry with an explicit ``:<count>`` suffix gets that integer amount;
    any other entry gets the string ``'any'``.
    """
    parsed = []
    for part in query.split(','):
        pieces = part.split(':')
        if len(pieces) == 2:
            entry = {'object': pieces[0].strip(), 'amount': int(pieces[1].strip())}
        else:
            entry = {'object': pieces[0].strip(), 'amount': 'any'}
        parsed.append(entry)
    return parsed
35
+
36
def make_response(status: int, message: str, data: dict | None = None) -> dict:
    """Wrap status/message/data into the API's uniform JSONResponse envelope."""
    content = {'status': status, 'message': message, 'data': data}
    return JSONResponse(content=content)
42
+
43
def download_from_bucket(dir: str) -> None:
    """Download every ``object-detection/*.json`` blob from GCS into ``dir``.

    Requires Google Cloud credentials to be configured in the environment.
    ``dir`` must already exist.
    """
    prefix = 'object-detection'
    client = storage.Client()
    bucket = client.get_bucket('thangtd1')
    blobs = bucket.list_blobs(prefix=prefix)
    for blob in blobs:
        if blob.name.endswith('.json'):
            # Strip the bucket folder prefix by its length instead of the
            # former magic constant `[16:]`, which silently broke if the
            # prefix ever changed.
            blob.download_to_filename(dir + '/' + blob.name[len(prefix):])
main.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from searcher import Searcher
2
+ from trie import Trie
3
+ from helper import parse_query, make_response, download_from_bucket
4
+
5
+ from fastapi import FastAPI, Request, status
6
+ from fastapi.exceptions import RequestValidationError
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import RedirectResponse
9
+ from pydantic import BaseModel
10
+
11
+ import settings
12
+ import os
13
+ import time
14
+
15
# Module-level singletons: the trie index is populated at startup and shared
# by every request through `searcher`.
trie = Trie()
searcher = Searcher(trie)
app = FastAPI(title="Object Search")

# Fully permissive CORS so any web frontend can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_methods=["*"],
)
26
+
27
+
28
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert pydantic validation failures into the uniform response envelope."""
    error_details = [
        {"error": f"{error['msg']} {str(error['loc'])}"}
        for error in exc.errors()
    ]
    return make_response(status=200, message="Bad Request", data=error_details)
37
+
38
+
39
@app.get("/", include_in_schema=False)
async def root() -> RedirectResponse:
    """Redirect the bare root URL to the interactive Swagger docs.

    The former ``-> None`` annotation was wrong: a RedirectResponse is returned.
    """
    return RedirectResponse("/docs")
42
+
43
+
44
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
async def perform_healthcheck():
    """Liveness probe: always answers 200 with message "OK" once the app is up.

    The former ``-> None`` annotation was wrong: make_response returns a
    JSONResponse, so the annotation is dropped rather than stated incorrectly.
    """
    return make_response(status=200, message="OK")
47
+
48
+
49
class Query(BaseModel):
    """Request body for POST /search."""

    # Comma-separated object list, e.g. "person: 2, dog"; parsed by helper.parse_query.
    query_text: str
    # Maximum number of candidates to return.
    topk: int
52
+
53
+
54
@app.post("/search", status_code=status.HTTP_200_OK, tags=["search"])
async def search(query: Query):
    """Search indexed frames for the queried objects.

    Returns up to ``query.topk`` candidates (video, frame_name, score) wrapped
    in the uniform response envelope.  (The former ``-> None`` annotation was
    incorrect — a JSONResponse is returned.)
    """
    topk = query.topk
    # Use a distinct name instead of rebinding `query`, which shadowed the
    # request model with the parsed list.
    parsed_query = parse_query(query.query_text)

    candidates = searcher.search(parsed_query, topk)
    data = [candidate.serialize() for candidate in candidates]
    return make_response(status=200, message="OK", data=data)
62
+
63
+
64
@app.on_event("startup")
async def startup_event():
    """Populate the shared trie: from cache.json when present, otherwise by
    downloading the raw detection files, building the index, and caching it."""
    if os.path.exists("cache.json"):
        # Fast path: rebuild directly from the cached serialized trie.
        start_time = time.time()
        trie.load_from_cache("cache.json")
        print("Load from cache took %.2f seconds" % (time.time() - start_time))
        return

    # Cold start: fetch detection JSONs, index them, then write the cache.
    if not os.path.exists("data"):
        os.mkdir("data")
    start_time = time.time()
    download_from_bucket("data")
    print("Download from bucket took %.2f seconds" % (time.time() - start_time))
    start_time = time.time()
    trie.load_from_dir("data")
    trie.save_to_cache("cache.json")
    print("Load from directory took %.2f seconds" % (time.time() - start_time))
80
+
81
+
82
if __name__ == "__main__":
    import uvicorn

    # Development entry point; reload=True restarts the server on code changes.
    uvicorn.run("main:app", host=settings.HOST, port=settings.PORT, reload=True)
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.6.0
2
+ anyio==3.7.1
3
+ beautifulsoup4==4.12.2
4
+ cachetools==5.3.1
5
+ certifi==2023.7.22
6
+ charset-normalizer==3.3.0
7
+ click==8.1.7
8
+ fastapi==0.104.0
9
+ google-api-core==2.12.0
10
+ google-auth==2.23.3
11
+ google-cloud-core==2.3.3
12
+ google-cloud-storage==2.12.0
13
+ google-crc32c==1.5.0
14
+ google-resumable-media==2.6.0
15
+ googleapis-common-protos==1.61.0
16
+ h11==0.14.0
17
+ idna==3.4
18
+ joblib==1.3.2
19
+ nltk==3.8.1
20
+ numpy==1.26.1
21
+ protobuf==4.24.4
22
+ pyasn1==0.5.0
23
+ pyasn1-modules==0.3.0
24
+ pydantic==2.4.2
25
+ pydantic_core==2.10.1
26
+ regex==2023.10.3
27
+ requests==2.31.0
28
+ rsa==4.9
29
+ sniffio==1.3.0
30
+ soupsieve==2.5
31
+ starlette==0.27.0
32
+ tqdm==4.66.1
33
+ typing_extensions==4.8.0
34
+ urllib3==2.0.7
35
+ uvicorn==0.23.2
searcher.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from candidate import Candidate
2
+ from frame import Frame
3
+ from helper import MULTI_OBJECT_BONUS, get_hypernym_path
4
+
5
+ from nltk.corpus import wordnet as wn
6
+
7
+
8
class Searcher:
    """Ranks frames stored in a hypernym trie against a parsed object query."""

    def __init__(self, trie):
        self.trie = trie

    def search(self, query: list[dict[str, str]], topk: int) -> list[Candidate]:
        """Return the top-``topk`` candidates for ``query``.

        ``query`` is parse_query output: ``[{'object': str, 'amount': int | 'any'}]``.
        A frame matched by several query terms accumulates each extra score
        plus MULTI_OBJECT_BONUS.

        Raises:
            ValueError: if an amount is neither an int nor the string 'any'.
        """
        # frame id -> accumulated score (a distinct name; the original reused
        # `candidates` for both the dict and the final list).
        scores: dict[str, float] = {}
        for q in query:
            obj, amount = q['object'], q['amount']
            hypernym_path = get_hypernym_path(obj)
            node_frames = self.trie.search(hypernym_path)
            if amount == 'any':
                this_candidates = [Candidate(nf.frame, nf.p_total) for nf in node_frames]
            elif isinstance(amount, int):
                # isinstance instead of `amount == int(amount)`: the latter
                # raised an unintended ValueError from int() on non-numeric
                # strings, making the explicit error below unreachable.
                this_candidates = [Candidate(nf.frame, nf.p_of(amount) * amount) for nf in node_frames]
            else:
                raise ValueError('Amount must be an integer or "any"')
            for candidate in this_candidates:
                if candidate.frame.id not in scores:
                    scores[candidate.frame.id] = candidate.score
                else:
                    # Repeat hit across query terms: add score plus bonus.
                    scores[candidate.frame.id] += candidate.score + MULTI_OBJECT_BONUS

        ranked = [Candidate(Frame(id=frame_id), score) for frame_id, score in scores.items()]
        ranked.sort(key=lambda candidate: candidate.score, reverse=True)
        return ranked[:topk]
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
settings.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ HOST = "0.0.0.0"
2
+ PORT = 7860
synsets.txt ADDED
The diff for this file is too large to render. See raw diff
 
trie.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from frame import Frame
2
+ from helper import OBJECT_MAP, get_hypernym_path
3
+
4
+ from numpy.polynomial import polynomial
5
+ from nltk.corpus import wordnet as wn
6
+ import json
7
+ import os
8
+
9
class NodeFrame:
    """A Frame plus the per-detection confidence list for one object class,
    with derived quantities used for scoring."""

    def __init__(self, frame: Frame, p_list: list[float]) -> None:
        self.frame = frame
        # Raw detection probabilities, one per detected instance.
        self.p_list = p_list
        # Expected number of instances (sum of the probabilities).
        self.p_total = self.calculate_p_total(p_list)
        # p_exactly[k] = P(exactly k of the detections are true).
        self.p_exactly = self.calculate_p_exactly(p_list)

    def calculate_p_total(self, p_list: list[float]) -> float:
        """Return the expected instance count: the sum of the probabilities."""
        return sum(p_list)

    def calculate_p_exactly(self, p_list: list[float]) -> list[float]:
        """Return [P(exactly 0), P(exactly 1), ...] assuming independent detections.

        Multiplying the polynomials (1-p) + p*x for each detection yields a
        polynomial whose k-th coefficient is P(exactly k true detections)
        (the Poisson-binomial distribution).
        """
        result = [1]
        p_list = [[1 - p, p] for p in p_list]
        for p in p_list:
            result = polynomial.polymul(result, p)
        return list(result)

    def p_of(self, amount: int) -> float:
        """Return P(exactly ``amount`` instances).

        Amounts beyond the number of detections get a geometric 0.1 penalty
        per missing detection instead of a hard zero (heuristic).
        """
        if amount < len(self.p_exactly):
            return self.p_exactly[amount]
        else:
            return self.p_exactly[-1] * (0.1 ** (amount - len(self.p_exactly) + 1))

    def serialize(self) -> dict:
        """Return a JSON-serializable dict (frame + raw probability list)."""
        return {
            'frame': self.frame.serialize(),
            'p_list': self.p_list,
        }
37
+
38
class Node:
    """A trie node: NodeFrames recorded at this synset plus children by word."""

    def __init__(self, node_frames: list[NodeFrame]) -> None:
        self.children = {}  # word -> Node
        self.node_frames = node_frames  # NodeFrames whose hypernym path ends here
42
+
43
class Trie:
    """Prefix tree over WordNet hypernym paths, mapping each path to NodeFrames."""

    def __init__(self) -> None:
        self.root = Node([])

    def insert(self, node_frame: NodeFrame, path: list[str]) -> None:
        """Attach ``node_frame`` at the node addressed by ``path``, creating
        intermediate nodes as needed."""
        node = self.root
        for word in path:
            if word not in node.children:
                node.children[word] = Node([])
            node = node.children[word]
        node.node_frames.append(node_frame)

    def search(self, path: list[str]) -> list[NodeFrame]:
        """Return all NodeFrames at ``path`` and anywhere below it ([] on miss)."""
        node = self.root
        for word in path:
            if word not in node.children:
                return []
            node = node.children[word]
        return self.search_all_children(node)

    def search_all_children(self, node: Node) -> list[NodeFrame]:
        """Collect NodeFrames from ``node`` and its entire subtree (DFS)."""
        result = []
        if len(node.node_frames) > 0:
            result.extend(node.node_frames)
        for child in node.children.values():
            result.extend(self.search_all_children(child))
        return result

    def load_from_dir(self, dir: str) -> None:
        """Index every per-video ``<video>.json`` detection file under ``dir``."""
        for path, _, files in os.walk(dir):
            for file in files:
                if not file.endswith('.json'):
                    continue
                # `with` closes the handle; the original's bare open() leaked it.
                with open(os.path.join(path, file)) as fp:
                    data = json.load(fp)
                video = file[:-len('.json')]  # strip suffix without a magic -5
                for frame_name, frame_data in data.items():
                    for object, p_list in frame_data.items():
                        hypernym_path = get_hypernym_path(object)
                        self.insert(NodeFrame(Frame(video=video, frame_name=frame_name), p_list), hypernym_path)

    def save_to_cache(self, cache_path: str) -> None:
        """Persist the serialized trie as JSON at ``cache_path``."""
        with open(cache_path, 'w') as fp:  # was an unclosed open()
            json.dump(self.serialize(), fp)

    def load_from_cache(self, cache_path: str) -> None:
        """Rebuild the trie from a JSON cache written by save_to_cache."""
        with open(cache_path) as fp:  # was an unclosed open()
            self.deserialize(json.load(fp))

    def serialize(self) -> dict:
        """Flatten to ``{'word1/word2/...': [serialized NodeFrame, ...]}``."""
        output = {}

        def dfs(node: Node, path: list[str]) -> None:
            # Record only nodes that actually hold frames.
            if len(node.node_frames) > 0:
                output['/'.join(path)] = [node_frame.serialize() for node_frame in node.node_frames]
            for word, child in node.children.items():
                dfs(child, path + [word])

        dfs(self.root, [])
        return output

    def deserialize(self, input):
        """Inverse of serialize: re-insert every cached NodeFrame."""
        for path, node_frames in input.items():
            path = path.split('/')
            for node_frame in node_frames:
                self.insert(NodeFrame(Frame(id=node_frame['frame']['id']), node_frame['p_list']), path)
103
+
104
+
105
+
106
+
107
+
108
+
109
+