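# NOTE: the lines below are residue from a web file-viewer page. They are kept
# as comments so the module parses.
# 欧卫
# 'add_app_files'
# 58627fa
# raw
# history blame
# No virus
# 2.79 kB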
import os
import time
import torch.multiprocessing as mp
from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.infra.launcher import Launcher
from colbert.utils.utils import create_directory, print_message
from colbert.indexing.collection_indexer import encode
class Indexer:
    """Build an on-disk index for a collection using a model checkpoint.

    The effective configuration is assembled in ``__init__`` from the
    checkpoint's stored config, the optional ``config`` argument and the
    active ``Run().config``. ``index()`` then writes the index under
    ``self.config.index_path_``.
    """

    def __init__(self, checkpoint, config=None):
        """
        Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.

        Args:
            checkpoint: checkpoint passed to ``ColBERTConfig.load_from_checkpoint``;
                it is also recorded in the merged config.
            config: optional config merged with the checkpoint config and
                ``Run().config`` via ``ColBERTConfig.from_existing``
                (NOTE(review): merge precedence is decided by ``from_existing`` - confirm).
        """
        self.index_path = None  # set by index(); erase() requires it to be set
        self.checkpoint = checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)

        self.configure(checkpoint=checkpoint)

    def configure(self, **kw_args):
        """Forward keyword settings to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def get_index(self):
        """Return the index path recorded by ``index()``, or None before it runs."""
        return self.index_path

    def erase(self):
        """Delete index artifacts already present in ``self.index_path``.

        A directory entry is selected when its file name either ends in ``.pt``,
        or ends in ``.json`` and contains ``metadata``, ``doclen`` or ``plan``.

        If anything is selected, a message is printed and the method sleeps
        20 seconds before removing the files. This gives a chance to interrupt.

        Returns:
            list[str]: full paths of the deleted files (empty if none matched).

        Raises:
            AssertionError: if ``self.index_path`` is still None.
        """
        assert self.index_path is not None
        directory = self.index_path
        deleted = []

        for entry in sorted(os.listdir(directory)):
            # Match on the bare entry name only. Testing the joined path would
            # also hit keywords that merely appear in the directory name
            # (e.g. ".../plan_runs/notes.json"), wrongly deleting unrelated files.
            is_index_json = entry.endswith(".json") and any(
                keyword in entry for keyword in ("metadata", "doclen", "plan")
            )
            if is_index_json or entry.endswith(".pt"):
                deleted.append(os.path.join(directory, entry))

        if deleted:
            print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
            time.sleep(20)  # grace period before anything is removed

            for path in deleted:
                os.remove(path)

        return deleted

    def index(self, name, collection, overwrite=False):
        """Index ``collection`` under ``name`` and return the index directory.

        Args:
            name: stored in the config as ``index_name``.
            collection: stored in the config and passed to the launcher.
            overwrite: one of
                - False: the index directory must not already exist;
                - True: run ``erase()`` on the existing directory, then index;
                - 'reuse': index only if the directory did not exist beforehand;
                - 'resume': set ``resume=True`` in the config, then index.

        Returns:
            The index path (``self.config.index_path_``).

        Raises:
            AssertionError: if ``overwrite`` is not one of the values above, or
                if the index directory exists and ``overwrite`` is False.
        """
        assert overwrite in [True, False, 'reuse', 'resume']

        self.configure(collection=collection, index_name=name, resume=overwrite == 'resume')
        # NOTE(review): batch size is pinned to 64 here; partitions=None presumably
        # lets the indexer pick the partition count - confirm.
        self.configure(bsize=64, partitions=None)

        self.index_path = self.config.index_path_
        index_does_not_exist = not os.path.exists(self.config.index_path_)

        assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
        create_directory(self.config.index_path_)

        if overwrite is True:
            self.erase()

        # 'reuse' skips encoding only when the index directory already existed.
        if index_does_not_exist or overwrite != 'reuse':
            self.__launch(collection)

        return self.index_path

    def __launch(self, collection):
        """Run ``encode`` through a ``Launcher`` with per-rank shared state.

        Creates one manager-backed list and one ``Queue(maxsize=1)`` for each of
        ``self.config.nranks`` ranks. These are passed with the config and
        collection to ``Launcher(encode).launch``.

        NOTE(review): the manager is never shut down explicitly. This relies on
        ``launch`` being synchronous and on the manager's finalizer - confirm.
        """
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues)