ngxquang
Add application file
db24a4e
raw
history blame
3.82 kB
from frame import Frame
from helper import OBJECT_MAP, get_hypernym_path
from numpy.polynomial import polynomial
from nltk.corpus import wordnet as wn
import json
import os
class NodeFrame:
    """A frame paired with a list of probabilities for one object label.

    Each entry of ``p_list`` is treated as an independent Bernoulli event
    (presumably one per detected instance of the object -- confirm against the
    producer of the JSON files), so the number of objects present follows a
    Poisson-binomial distribution, stored in ``p_exactly``.
    """

    def __init__(self, frame: "Frame", p_list: list[float]) -> None:
        self.frame = frame
        self.p_list = p_list
        # Expected number of occurring events: the sum of the probabilities.
        self.p_total = self.calculate_p_total(p_list)
        # p_exactly[k] is the probability that exactly k events occur.
        self.p_exactly = self.calculate_p_exactly(p_list)

    def calculate_p_total(self, p_list: list[float]) -> float:
        """Return the sum of ``p_list`` (the expected number of events)."""
        return sum(p_list)

    def calculate_p_exactly(self, p_list: list[float]) -> list[float]:
        """Return the distribution of how many of the events in ``p_list`` occur.

        Element ``k`` is the probability that exactly ``k`` of the independent
        events with probabilities ``p_list`` occur. The result always has
        ``len(p_list) + 1`` entries (``[1.0]`` for an empty list).

        This is the coefficient list of ``prod((1 - p) + p * x)``. It is built
        by explicit convolution rather than ``polynomial.polymul`` because
        ``polymul`` trims trailing zero coefficients: with any ``p == 0`` the
        list would be shorter than ``len(p_list) + 1`` and ``p_of`` would use
        its extrapolation for counts whose true probability is exactly zero.
        """
        dist = [1.0]
        for p in p_list:
            q = 1.0 - p
            nxt = [0.0] * (len(dist) + 1)
            for k, mass in enumerate(dist):
                nxt[k] += mass * q      # event does not occur: count stays k
                nxt[k + 1] += mass * p  # event occurs: count becomes k + 1
            dist = nxt
        return dist

    def p_of(self, amount: int) -> float:
        """Return the score that exactly ``amount`` objects are present.

        For ``0 <= amount <= len(p_list)`` this is the exact probability from
        ``p_exactly``. Larger counts cannot arise from ``p_list``, so instead
        of 0 the last entry of ``p_exactly`` is returned, decayed by a factor
        of 0.1 per extra object. A negative ``amount`` returns 0.0.
        """
        if amount < 0:
            # A negative index would otherwise silently read from the end.
            return 0.0
        size = len(self.p_exactly)
        if amount < size:
            return self.p_exactly[amount]
        # Heuristic non-zero tail -- presumably so over-counted queries are
        # down-weighted rather than excluded outright.
        return self.p_exactly[-1] * (0.1 ** (amount - size + 1))

    def serialize(self) -> dict:
        """Return a JSON-ready dict with the serialized frame and the raw ``p_list``."""
        return {
            'frame': self.frame.serialize(),
            'p_list': self.p_list,
        }
class Node:
    """A single trie vertex.

    Holds the NodeFrames inserted at exactly this vertex, plus a mapping from
    the next path word to the corresponding child vertex.
    """

    def __init__(self, node_frames: list[NodeFrame]) -> None:
        # The caller's list is kept by reference, not copied.
        self.node_frames = node_frames
        # Maps path word -> child Node.
        self.children = dict()
class Trie:
    """Prefix tree of NodeFrames keyed by a list of words.

    ``load_from_dir`` keys each NodeFrame by the list returned from
    ``get_hypernym_path`` for its object label, so a ``search`` on a prefix
    returns frames for that label and every label stored beneath it.
    """

    def __init__(self) -> None:
        self.root = Node([])

    def insert(self, node_frame: NodeFrame, path: list[str]) -> None:
        """Append ``node_frame`` to the node at ``path``, creating missing nodes."""
        node = self.root
        for word in path:
            if word not in node.children:
                node.children[word] = Node([])
            node = node.children[word]
        node.node_frames.append(node_frame)

    def search(self, path: list[str]) -> list[NodeFrame]:
        """Return all NodeFrames stored at ``path`` or anywhere below it.

        Returns an empty list when ``path`` does not exist in the trie.
        """
        node = self.root
        for word in path:
            node = node.children.get(word)
            if node is None:
                return []
        return self.search_all_children(node)

    def search_all_children(self, node: Node) -> list[NodeFrame]:
        """Collect NodeFrames from ``node`` and its descendants, in pre-order."""
        result = list(node.node_frames)
        for child in node.children.values():
            result.extend(self.search_all_children(child))
        return result

    def load_from_dir(self, dir: str) -> None:
        """Index every ``*.json`` file found (recursively) under ``dir``.

        Each file is read as a mapping ``frame_name -> {label -> p_list}``.
        The video name is the file name without its ``.json`` suffix. (The
        parameter name shadows the builtin but is kept for caller compatibility.)
        """
        for root, _, files in os.walk(dir):
            for file_name in files:
                if not file_name.endswith('.json'):
                    continue
                # `with` closes each file promptly; JSON text is UTF-8 by spec.
                with open(os.path.join(root, file_name), encoding='utf-8') as fh:
                    data = json.load(fh)
                video = file_name[:-len('.json')]
                for frame_name, frame_data in data.items():
                    for label, p_list in frame_data.items():
                        hypernym_path = get_hypernym_path(label)
                        self.insert(
                            NodeFrame(Frame(video=video, frame_name=frame_name), p_list),
                            hypernym_path,
                        )

    def save_to_cache(self, cache_path: str) -> None:
        """Write the ``serialize`` output to ``cache_path`` as JSON."""
        # Serialize first so a failure does not truncate an existing cache;
        # `with` guarantees the data is flushed and the handle closed.
        payload = self.serialize()
        with open(cache_path, 'w', encoding='utf-8') as fh:
            json.dump(payload, fh)

    def load_from_cache(self, cache_path: str) -> None:
        """Insert the NodeFrames stored in a ``save_to_cache`` JSON file."""
        with open(cache_path, encoding='utf-8') as fh:
            payload = json.load(fh)
        self.deserialize(payload)

    def serialize(self) -> dict:
        """Return ``{'/'-joined path: [serialized NodeFrame, ...]}`` for every non-empty node.

        The root's key is the empty string. NOTE(review): a path word that
        itself contains '/' will not round-trip through ``deserialize``.
        """
        output = {}

        def dfs(node: Node, path: list[str]) -> None:
            if node.node_frames:
                output['/'.join(path)] = [node_frame.serialize() for node_frame in node.node_frames]
            for word, child in node.children.items():
                dfs(child, path + [word])

        dfs(self.root, [])
        return output

    def deserialize(self, input: dict) -> None:
        """Insert the NodeFrames described by a ``serialize``-style dict.

        Each frame is rebuilt as ``Frame(id=...)`` from the ``'id'`` key of its
        serialized form (assumes ``Frame.serialize`` emits ``'id'`` -- TODO confirm).
        """
        for key, node_frames in input.items():
            # '' is the root's key; ''.split('/') would yield [''] and wrongly
            # create a child named '' instead of targeting the root.
            path = key.split('/') if key else []
            for node_frame in node_frames:
                frame = Frame(id=node_frame['frame']['id'])
                self.insert(NodeFrame(frame, node_frame['p_list']), path)