MAS-AI-0000 committed
Commit 4687d89 · verified · 1 Parent(s): 02cd288

Update detree/utils/index.py

Files changed (1)
  1. detree/utils/index.py +112 -105
detree/utils/index.py CHANGED
@@ -1,105 +1,112 @@
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import os
- import pickle
- from typing import List, Tuple
-
- import faiss
- import numpy as np
- from tqdm import tqdm
-
- class Indexer(object):
-
-     def __init__(self, vector_sz, n_subquantizers=0, n_bits=16):
-         # if n_subquantizers > 0:
-         #     self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
-         # else:
-         self.vector_sz = vector_sz
-         self.index = self._create_sharded_index()
-         self.index_id_to_db_id = []
-         self.label_dict = {}
-         # self.index = faiss.IndexFlatIP(vector_sz)
-         # self.index = faiss.index_cpu_to_all_gpus(self.index)
-         # self.index_id_to_db_id = np.empty((0), dtype=np.int64)
-         # self.index_id_to_db_id = []
-         # self.label_dict = {}
-
-     def _create_sharded_index(self):
-         # Determine the number of available GPUs
-         ngpu = faiss.get_num_gpus()
-         # Create an IndexShards object with successive_ids=True to keep ids globally unique
-         index = faiss.IndexShards(self.vector_sz, True, True)
-         # Create a sub-index for each GPU and add it to the IndexShards container
-         for i in range(ngpu):
-             # Create a standard GPU resource object
-             res = faiss.StandardGpuResources()
-             # Configure the GPU index
-             flat_config = faiss.GpuIndexFlatConfig()
-             # flat_config.useFloat16 = True  # enable to reduce memory usage with half precision
-             flat_config.device = i  # assign the GPU device id
-             # Create the GPU index
-             sub_index = faiss.GpuIndexFlatIP(res, self.vector_sz, flat_config)
-             # Add the sub-index into the sharded index
-             index.add_shard(sub_index)
-         return index
-
-     def index_data(self, ids, embeddings):
-         self._update_id_mapping(ids)
-         # if not self.index.is_trained:
-         #     self.index.train(embeddings)
-         self.index.add(embeddings)
-
-         print(f'Total data indexed {self.index.ntotal}')
-
-     def search_knn(self, query_vectors: np.ndarray, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[object], List[float], List[object]]]:
-         result = []
-         nbatch = (len(query_vectors) - 1) // index_batch_size + 1
-         for k in tqdm(range(nbatch)):
-             start_idx = k * index_batch_size
-             end_idx = min((k + 1) * index_batch_size, len(query_vectors))
-             q = query_vectors[start_idx:end_idx]
-             scores, indexes = self.index.search(q, top_docs)
-             # convert internal faiss ids to external db ids and labels
-             db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
-             db_labels = [[self.label_dict[self.index_id_to_db_id[i]] for i in query_top_idxs] for query_top_idxs in indexes]
-             result.extend([(db_ids[i], scores[i], db_labels[i]) for i in range(len(db_ids))])
-         return result
-
-     def serialize(self, dir_path):
-         index_file = os.path.join(dir_path, 'index.faiss')
-         meta_file = os.path.join(dir_path, 'index_meta.faiss')
-         print(f'Serializing index to {index_file}, meta data to {meta_file}')
-
-         faiss.write_index(self.index, index_file)
-         with open(meta_file, mode='wb') as f:
-             pickle.dump(self.index_id_to_db_id, f)
-
-     def deserialize_from(self, dir_path):
-         index_file = os.path.join(dir_path, 'index.faiss')
-         meta_file = os.path.join(dir_path, 'index_meta.faiss')
-         print(f'Loading index from {index_file}, meta data from {meta_file}')
-
-         self.index = faiss.read_index(index_file)
-         print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')
-
-         with open(meta_file, 'rb') as reader:
-             self.index_id_to_db_id = pickle.load(reader)
-         assert len(self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
-
-     def _update_id_mapping(self, db_ids: List):
-         # new_ids = np.array(db_ids, dtype=np.int64)
-         # self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
-         self.index_id_to_db_id.extend(db_ids)
-
-     def reset(self):
-         self.index.reset()
-         self.index_id_to_db_id = []
-         print(f'Index reset, total data indexed {self.index.ntotal}')
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import pickle
+ from typing import List, Tuple
+
+ import faiss
+ import numpy as np
+ from tqdm import tqdm
+
+ class Indexer(object):
+
+     def __init__(self, vector_sz, n_subquantizers=0, n_bits=16):
+         # if n_subquantizers > 0:
+         #     self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
+         # else:
+         self.vector_sz = vector_sz
+         self.index = self._create_sharded_index()
+         self.index_id_to_db_id = []
+         self.label_dict = {}
+         # self.index = faiss.IndexFlatIP(vector_sz)
+         # self.index = faiss.index_cpu_to_all_gpus(self.index)
+         # self.index_id_to_db_id = np.empty((0), dtype=np.int64)
+         # self.index_id_to_db_id = []
+         # self.label_dict = {}
+
+     def _create_sharded_index(self):
+         # Determine the number of available GPUs
+         ngpu = faiss.get_num_gpus()
+
+         # If no GPUs are available, fall back to a flat CPU index
+         if ngpu == 0:
+             print("No GPUs detected. Using CPU index.")
+             return faiss.IndexFlatIP(self.vector_sz)
+
+         # Create an IndexShards object with successive_ids=True to keep ids globally unique
+         index = faiss.IndexShards(self.vector_sz, True, True)
+         # Create a sub-index for each GPU and add it to the IndexShards container
+         for i in range(ngpu):
+             # Create a standard GPU resource object
+             res = faiss.StandardGpuResources()
+             # Configure the GPU index
+             flat_config = faiss.GpuIndexFlatConfig()
+             # flat_config.useFloat16 = True  # enable to reduce memory usage with half precision
+             flat_config.device = i  # assign the GPU device id
+             # Create the GPU index
+             sub_index = faiss.GpuIndexFlatIP(res, self.vector_sz, flat_config)
+             # Add the sub-index into the sharded index
+             index.add_shard(sub_index)
+         return index
+
+     def index_data(self, ids, embeddings):
+         self._update_id_mapping(ids)
+         # if not self.index.is_trained:
+         #     self.index.train(embeddings)
+         self.index.add(embeddings)
+
+         print(f'Total data indexed {self.index.ntotal}')
+
+     def search_knn(self, query_vectors: np.ndarray, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[object], List[float], List[object]]]:
+         result = []
+         nbatch = (len(query_vectors) - 1) // index_batch_size + 1
+         for k in tqdm(range(nbatch)):
+             start_idx = k * index_batch_size
+             end_idx = min((k + 1) * index_batch_size, len(query_vectors))
+             q = query_vectors[start_idx:end_idx]
+             scores, indexes = self.index.search(q, top_docs)
+             # convert internal faiss ids to external db ids and labels
+             db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
+             db_labels = [[self.label_dict[self.index_id_to_db_id[i]] for i in query_top_idxs] for query_top_idxs in indexes]
+             result.extend([(db_ids[i], scores[i], db_labels[i]) for i in range(len(db_ids))])
+         return result
+
+     def serialize(self, dir_path):
+         index_file = os.path.join(dir_path, 'index.faiss')
+         meta_file = os.path.join(dir_path, 'index_meta.faiss')
+         print(f'Serializing index to {index_file}, meta data to {meta_file}')
+
+         faiss.write_index(self.index, index_file)
+         with open(meta_file, mode='wb') as f:
+             pickle.dump(self.index_id_to_db_id, f)
+
+     def deserialize_from(self, dir_path):
+         index_file = os.path.join(dir_path, 'index.faiss')
+         meta_file = os.path.join(dir_path, 'index_meta.faiss')
+         print(f'Loading index from {index_file}, meta data from {meta_file}')
+
+         self.index = faiss.read_index(index_file)
+         print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')
+
+         with open(meta_file, 'rb') as reader:
+             self.index_id_to_db_id = pickle.load(reader)
+         assert len(self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
+
+     def _update_id_mapping(self, db_ids: List):
+         # new_ids = np.array(db_ids, dtype=np.int64)
+         # self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
+         self.index_id_to_db_id.extend(db_ids)
+
+     def reset(self):
+         self.index.reset()
+         self.index_id_to_db_id = []
+         print(f'Index reset, total data indexed {self.index.ntotal}')
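
For context, here is a minimal usage sketch of the updated Indexer (not part of the commit). The dimensions, ids, labels, and random embeddings below are invented for illustration; faiss expects float32 inputs.

import numpy as np

from detree.utils.index import Indexer

# Hypothetical data: 1,000 documents with 768-dim embeddings.
dim = 768
ids = list(range(1000))
embeddings = np.random.rand(1000, dim).astype('float32')  # faiss requires float32

indexer = Indexer(dim)
# label_dict is read by search_knn but never filled by the class itself,
# so the caller must populate it before searching.
indexer.label_dict = {i: f'label_{i % 10}' for i in ids}
indexer.index_data(ids, embeddings)

queries = np.random.rand(4, dim).astype('float32')
for doc_ids, scores, labels in indexer.search_knn(queries, top_docs=5):
    print(doc_ids, scores, labels)

One caveat: faiss.write_index cannot serialize GPU-resident indexes, so serialize() as written will fail when the sharded GPU index is in use; converting first with faiss.index_gpu_to_cpu(indexer.index) should work, though this file does not do that.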