File size: 1,371 Bytes
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import faiss
import numpy as np

class FaissNeighbors:
  """Exact k-nearest-neighbour lookup under (squared) L2 distance, via FAISS.

  ``fit`` stores reference vectors in a ``faiss.IndexFlatL2`` together with
  their labels; the query methods return, for every query row, the ``top_K``
  closest reference rows.
  """

  def __init__(self):
    self.index = None  # faiss.IndexFlatL2, built by fit()
    self.y = None      # ndarray of labels, positionally aligned with fit()'s X

  def fit(self, X, y):
    """Index the rows of ``X`` and remember their labels ``y``.

    Args:
      X: 2-D array-like of shape (n_samples, dim). Converted to a
        C-contiguous float32 array (copied only when needed), which is the
        layout FAISS works on.
      y: array-like of length n_samples; ``y[i]`` is the label of ``X[i]``.
    """
    X = np.ascontiguousarray(X, dtype=np.float32)
    self.index = faiss.IndexFlatL2(X.shape[1])
    self.index.add(X)
    # FAISS returns positional row ids, so keep labels as a plain ndarray:
    # a list cannot be fancy-indexed with an id array, and a pandas Series
    # may index by label rather than by position.
    self.y = np.asarray(y)

  def get_distances_and_indices(self, X, top_K=1000):
    """Find the ``top_K`` nearest indexed vectors for each row of ``X``.

    Args:
      X: 2-D array-like of query vectors, shape (n_queries, dim).
      top_K: number of neighbours to return per query.

    Returns:
      ``(distances, indices, labels)``, each of shape (n_queries, top_K):
      the distances reported by ``IndexFlatL2`` (squared L2), the row ids
      into the fitted ``X``, and the matching entries of the fitted ``y``.

    NOTE(review): if ``top_K`` exceeds the number of indexed vectors, FAISS
    pads the missing slots with index -1, and ``self.y[-1]`` then silently
    yields the last fitted label there. Mask ``indices == -1`` if that case
    can occur.
    """
    queries = np.ascontiguousarray(X, dtype=np.float32)
    distances, indices = self.index.search(queries, k=top_K)
    # search() allocates fresh result arrays and fancy indexing already
    # returns a copy, so no defensive np.copy is needed.
    return distances, indices, self.y[indices]

  def get_nearest_labels(self, X, top_K=1000):
    """Return only the labels of the ``top_K`` nearest neighbours of each row.

    Same as the third element of ``get_distances_and_indices``; shape
    (n_queries, top_K). The -1 padding caveat documented there applies.
    """
    return self.get_distances_and_indices(X, top_K=top_K)[2]


class FaissCosineNeighbors:
  """Exact k-nearest-neighbour lookup under cosine similarity, via FAISS.

  Reference and query vectors are L2-normalised and searched in a flat
  inner-product index, so the scores FAISS returns are the cosine
  similarities of the original vectors (higher means closer).
  """

  def __init__(self):
    self.cindex = None  # flat METRIC_INNER_PRODUCT index, built by fit()
    self.y = None       # ndarray of labels, positionally aligned with fit()'s X

  @staticmethod
  def _normalized_float32_copy(A):
    """Return a fresh C-contiguous float32 copy of ``A`` with L2-normalised rows."""
    # faiss.normalize_L2 rewrites its argument in place and only handles
    # C-contiguous float32 data, so cast *before* normalising (a float64
    # input would otherwise be rejected). np.array copies, which keeps the
    # caller's array untouched.
    A = np.array(A, dtype=np.float32, order="C")
    faiss.normalize_L2(A)
    return A

  def fit(self, X, y):
    """Index the L2-normalised rows of ``X`` and remember their labels ``y``.

    Args:
      X: 2-D array-like of shape (n_samples, dim); not modified.
      y: array-like of length n_samples; ``y[i]`` is the label of ``X[i]``.
    """
    X = self._normalized_float32_copy(X)
    self.cindex = faiss.index_factory(X.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
    self.cindex.add(X)
    # FAISS returns positional row ids; see FaissCosineNeighbors.get_nearest_labels.
    # A list cannot be fancy-indexed and a pandas Series may index by label,
    # so store a plain ndarray.
    self.y = np.asarray(y)

  def get_distances_and_indices(self, Q, topK=1000):
    """Find the ``topK`` most cosine-similar indexed vectors for each row of ``Q``.

    Args:
      Q: 2-D array-like of query vectors, shape (n_queries, dim); not modified.
      topK: number of neighbours to return per query.

    Returns:
      ``(distances, indices, labels)``, each of shape (n_queries, topK).
      Despite the name, ``distances`` holds inner products of normalised
      vectors, i.e. cosine similarities (larger is closer). ``indices`` are
      row ids into the fitted ``X`` and ``labels`` the matching ``y`` entries.

    NOTE(review): if ``topK`` exceeds the number of indexed vectors, FAISS
    pads the missing slots with index -1, and ``self.y[-1]`` then silently
    yields the last fitted label there. Mask ``indices == -1`` if that case
    can occur.
    """
    queries = self._normalized_float32_copy(Q)
    distances, indices = self.cindex.search(queries, k=topK)
    # search() allocates fresh result arrays and fancy indexing already
    # returns a copy, so no defensive np.copy is needed.
    return distances, indices, self.y[indices]

  def get_nearest_labels(self, Q, topK=1000):
    """Return only the labels of the ``topK`` most similar vectors per query.

    Same as the third element of ``get_distances_and_indices``; shape
    (n_queries, topK). The -1 padding caveat documented there applies.
    """
    return self.get_distances_and_indices(Q, topK=topK)[2]