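"""
Build a ColBERT index over the parsed anthology collection and run an example search.

Expects collection.json and dataset.json in the working directory (see paths below).
"""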
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevents deadlocks in ColBERT tokenization
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Allows multiple libraries in OpenMP runtime. This can cause unexected behavior, but allows ColBERT to work
import json
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig, ColBERTConfig
INDEX_NAME = 'index'
ANTHOLOGY_PATH = 'anthology.bib'
COLLECTION_PATH = 'collection.json'
DATASET_PATH = 'dataset.json'
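# Assumed format: COLLECTION_PATH holds a JSON list of passage strings, which
# ColBERT's Indexer/Searcher can consume directly, e.g.
#   ["Paper title. Abstract text ...", "Another title. Abstract ...", ...]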
nbits = 2 # encode each dimension with 2 bits
doc_maxlen = 300 # truncate passages at 300 tokens
checkpoint = 'colbert-ir/colbertv2.0' # ColBERT model to use
def index_anthology(collection, index_name=INDEX_NAME):
    """Encode the collection and write the ColBERT index to disk."""
    with Run().context(RunConfig(nranks=1, experiment='notebook')):  # nranks specifies the number of GPUs to use
        config = ColBERTConfig(
            doc_maxlen=doc_maxlen,
            nbits=nbits,
            kmeans_niters=4,  # number of k-means clustering iterations; 4 is a good and fast default
            index_path=INDEX_NAME
        )
        indexer = Indexer(
            checkpoint=checkpoint,
            config=config
        )
        indexer.index(
            name=index_name,
            collection=collection,
            overwrite=True
        )
def search_anthology(collection, index_name=INDEX_NAME):
    """Load the ColBERT index and run an example query against the collection."""
    with Run().context(RunConfig(nranks=0, experiment='notebook')):
        searcher = Searcher(index=index_name, collection=collection)

        queries = ["What are some recent examples of grammar checkers?"]
        for query in queries:
            print(f"#> {query}")
            results = searcher.search(query, k=3)  # Find the top-3 passages for this query

            # Print out the top-k retrieved passages
            for passage_id, passage_rank, passage_score in zip(*results):
                print(f"\t [{passage_rank}] \t\t {passage_score:.1f} \t\t {searcher.collection[passage_id]}")
            print(results)
if __name__ == '__main__':
    # Load the parsed anthology
    with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
        collection = json.loads(f.read())

    # dataset.json is loaded here but not used by the indexing/search calls below
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        dataset = json.loads(f.read())

    index_anthology(collection, index_name=INDEX_NAME)
    search_anthology(collection, index_name=INDEX_NAME)