# sonajaht-demo / build_ann.py
# Author: adorkin
# Commit: "Update build_ann.py" (eb0ad1b, verified)
# raw | history | blame | 882 Bytes
from annoy import AnnoyIndex
from safetensors import safe_open
from tqdm import trange
# Read the stored vector matrix (tensor key "vectors") from the safetensors
# file as a NumPy array.
safetensors_path = "definitions.safetensors"
with safe_open(safetensors_path, framework="numpy") as tensor_file:
    vectors = tensor_file.get_tensor("vectors")
# Tuple unpacking requires a 2-D array: rows are the items indexed below,
# columns are the vector dimensions.
num_vectors, vector_dim = vectors.shape
print(f"Loaded {num_vectors} vectors of dimension {vector_dim}")
# Build an approximate-nearest-neighbour index over every row of `vectors`
# using Annoy's "angular" metric, then persist it next to the input file.
num_trees = 25  # more trees: better recall, larger index and slower build
index = AnnoyIndex(vector_dim, "angular")
for item_id in trange(num_vectors):
    # Annoy item ids are simply the row positions in `vectors`.
    index.add_item(item_id, vectors[item_id])
index.build(num_trees)
index.save("definitions.ann")
# Sanity check: query the freshly built index with the first stored vector.
# That vector is itself item 0, so it is expected to show up among its own
# nearest neighbours (Annoy search is approximate, so this is not guaranteed).
query_vector = vectors[0]
num_neighbors = 5
# One search with include_distances=True returns an (ids, distances) pair.
# The ids are read from that result instead of repeating the same search
# without distances.
neighbors_with_distances = index.get_nns_by_vector(
    query_vector, num_neighbors, include_distances=True
)
nearest_neighbors = neighbors_with_distances[0]
print(f"Indices of {num_neighbors} nearest neighbors:", nearest_neighbors)
print("Neighbors with distances:", neighbors_with_distances)