edia_full_en / modules /module_ann.py
nanom's picture
First commit
e8aad19
raw
history blame
2.07 kB
import time
from tqdm import tqdm
from annoy import AnnoyIndex
from typing import List
class TicToc:
def __init__(
self
) -> None:
self.i = None
def start(
self
) -> None:
self.i = time.time()
def stop(
self
) -> None:
f = time.time()
print(f - self.i, "seg.")
class Ann:
def __init__(
self,
words: List[str],
vectors: List,
coord: List,
) -> None:
self.words = words
self.vectors = vectors
self.coord = coord
self.tree = None
self.tt = TicToc()
self.availables_metrics = ['angular','euclidean','manhattan','hamming','dot']
def init(self,
n_trees: int=10,
metric: str='angular',
n_jobs: int=-1 # n_jobs=-1 Run over all CPU availables
) -> None:
assert(metric in self.availables_metrics), f"Error: The value of the parameter 'metric' can only be {self.availables_metrics}!"
print("\tInit tree...")
self.tt.start()
self.tree = AnnoyIndex(len(self.vectors[0]), metric=metric)
for i, v in tqdm(enumerate(self.vectors), total=len(self.vectors)):
self.tree.add_item(i, v)
self.tt.stop()
print("\tBuild tree...")
self.tt.start()
self.tree.build(n_trees=n_trees, n_jobs=n_jobs)
self.tt.stop()
def __getWordId(
self,
word: str
) -> int:
word_id = None
try:
word_id = self.words.index(word)
except:
pass
return word_id
def get(
self,
word: str,
n_neighbors: int=10
) -> List[str]:
word_id = self.__getWordId(word)
neighbors_list = None
if word_id != None:
neighbords_id = self.tree.get_nns_by_item(word_id, n_neighbors + 1)
neighbors_list = [self.words[idx] for idx in neighbords_id][1:]
else:
print(f"The word '{word}' does not exist")
return neighbors_list