# colbert-xm / custom.py
# antoinelouis's picture
# Create custom.py
# 075dfca verified
import os
from langdetect import detect
import torch.multiprocessing as mp
from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run
from colbert.utils.utils import print_message
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.index_saver import IndexSaver
from colbert.search.index_storage import IndexScorer
from colbert.infra.launcher import Launcher, print_memory_stats
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer
# Language tables.
#
# Every table maps a two-letter language code (the key that
# `set_xmod_language` looks up after stripping any "-region" suffix) to a
# pair `(english_name, adapter_code)`. Only the second element is consumed in
# this file: it is the value handed to `model.set_default_language(...)`.
#
# NOTE(review): the table names suggest the language sets of the mMARCO,
# Mr. TyDi and MIRACL benchmarks respectively -- inferred from the names only.
MMARCO_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "de": ("german", "de_DE"),
    "en": ("english", "en_XX"),
    "es": ("spanish", "es_XX"),
    "fr": ("french", "fr_XX"),
    "hi": ("hindi", "hi_IN"),
    "id": ("indonesian", "id_ID"),
    "it": ("italian", "it_IT"),
    "ja": ("japanese", "ja_XX"),
    "nl": ("dutch", "nl_XX"),
    "pt": ("portuguese", "pt_XX"),
    "ru": ("russian", "ru_RU"),
    "vi": ("vietnamese", "vi_VN"),
    "zh": ("chinese", "zh_CN"),
}

MRTYDI_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "bn": ("bengali", "bn_IN"),
    "en": ("english", "en_XX"),
    "fi": ("finnish", "fi_FI"),
    "id": ("indonesian", "id_ID"),
    "ja": ("japanese", "ja_XX"),
    "ko": ("korean", "ko_KR"),
    "ru": ("russian", "ru_RU"),
    "sw": ("swahili", "sw_KE"),
    "te": ("telugu", "te_IN"),
    "th": ("thai", "th_TH"),
}

MIRACL_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "bn": ("bengali", "bn_IN"),
    "en": ("english", "en_XX"),
    "es": ("spanish", "es_XX"),
    "fa": ("persian", "fa_IR"),
    "fi": ("finnish", "fi_FI"),
    "fr": ("french", "fr_XX"),
    "hi": ("hindi", "hi_IN"),
    "id": ("indonesian", "id_ID"),
    "ja": ("japanese", "ja_XX"),
    "ko": ("korean", "ko_KR"),
    "ru": ("russian", "ru_RU"),
    "sw": ("swahili", "sw_KE"),
    "te": ("telugu", "te_IN"),
    "th": ("thai", "th_TH"),
    "zh": ("chinese", "zh_CN"),
}

# Union of the three tables. Tables are merged in the order listed, so a code
# keeps the position of its first appearance and the value of its last one
# (the duplicated entries are identical across tables, so only order matters).
ALL_LANGUAGES = {
    code: entry
    for table in (MMARCO_LANGUAGES, MRTYDI_LANGUAGES, MIRACL_LANGUAGES)
    for code, entry in table.items()
}
def set_xmod_language(model, lang: str):
    """
    Select the default language of `model` from a language code.

    Only the part of `lang` before the first '-' is used (so a code such as
    "zh-cn" is looked up as "zh"). The matching entry of `ALL_LANGUAGES`
    supplies the identifier that is passed to `model.set_default_language`.
    Source: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/xmod/modeling_xmod.py#L687

    :param model: object exposing a `set_default_language(code)` method.
    :param lang: language code, optionally followed by a '-' separated suffix.
    :raises KeyError: if the base code has no entry in `ALL_LANGUAGES`.
    """
    # Keep only the base code; partition() yields the same head as split('-')[0].
    lang = lang.partition('-')[0]
    entry = ALL_LANGUAGES.get(lang)
    if entry is None:
        raise KeyError(f"Language {lang} not supported.")
    # entry is (english_name, adapter_code); only the adapter code is needed.
    model.set_default_language(entry[1])
#-----------------------------------------------------------------------------------------------------------------#
# INDEXER
#-----------------------------------------------------------------------------------------------------------------#
class CustomIndexer(Indexer):
    """
    Indexer variant whose launcher runs `custom_encode` in the worker processes.

    Only the launch step is customised; everything else is inherited from
    `Indexer`.
    """

    def __launch(self, collection):
        """
        Start the indexing workers for `collection`.

        Creates one manager-backed list and one single-slot queue per rank
        (`self.config.nranks` of each) and hands them, together with the
        config, the collection and the verbosity level, to a `Launcher`
        wrapping `custom_encode`.

        :param collection: the collection to index; forwarded unchanged.
        """
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]
        launcher = Launcher(custom_encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)

    # Python mangles double-underscore names per class: the method above is
    # stored as `_CustomIndexer__launch`, while a `self.__launch(...)` call
    # written inside `Indexer` resolves to `_Indexer__launch`. Without this
    # alias the method above can therefore never override the parent's private
    # launcher, and the stock encoder would be used instead of `custom_encode`.
    # NOTE(review): assumes the installed `colbert.Indexer` invokes its
    # launcher as `self.__launch(collection)` -- confirm against that version;
    # the alias is harmless if the parent uses a different name.
    _Indexer__launch = __launch
def custom_encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    """
    Worker entry point: build a `CustomCollectionIndexer` and run it.

    :param config: configuration forwarded to the indexer.
    :param collection: collection forwarded to the indexer.
    :param shared_lists: passed to the indexer's `run` method.
    :param shared_queues: accepted so the signature matches what the launcher
        passes, but not used here.
    :param verbose: verbosity level forwarded to the indexer.
    """
    CustomCollectionIndexer(
        config=config,
        collection=collection,
        verbose=verbose,
    ).run(shared_lists)
class CustomCollectionIndexer(CollectionIndexer):
    """
    `CollectionIndexer` that picks the X-MOD language from the collection.

    The constructor sets every attribute itself (it does not call the parent
    constructor). When the checkpoint's underlying model class name starts
    with "xmod", the language of the first collection item is detected with
    `langdetect.detect` and applied through `set_xmod_language`.
    """

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        """
        :param config: indexing configuration; also provides rank, number of
            ranks, visible GPU count and the checkpoint to load.
        :param collection: anything accepted by `Collection.cast`.
        :param verbose: verbosity level; values above 1 make rank 0 print the
            configuration help.
        """
        self.verbose = verbose
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks
        self.use_gpu = self.config.total_visible_gpus > 0

        # Only the first rank prints the configuration, and only when verbose.
        if self.config.rank == 0:
            if self.verbose > 1:
                self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)

        # X-MOD models need a language selected before encoding; other
        # backbones are left untouched.
        backbone_name = self.checkpoint.bert.__class__.__name__.lower()
        if backbone_name.startswith("xmod"):
            # The language is inferred from the first item of the collection only.
            first_item = self.collection[0]
            language = detect(first_item)
            Run().print_main(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)

        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)

        print_memory_stats(f'RANK:{self.rank}')
#-----------------------------------------------------------------------------------------------------------------#
# SEARCHER
#-----------------------------------------------------------------------------------------------------------------#
class CustomSearcher(Searcher):
    """
    `Searcher` that picks the X-MOD language from the collection.

    The constructor sets every attribute itself (it does not call the parent
    constructor). When the checkpoint's underlying model class name starts
    with "xmod", the language of the first collection item is detected with
    `langdetect.detect` and applied through `set_xmod_language`.
    """

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose: int = 3):
        """
        :param index: index name, joined onto the index root directory.
        :param checkpoint: checkpoint to load; falls back to the one recorded
            in the index configuration.
        :param collection: collection to search; falls back to the one in the
            merged configuration.
        :param config: optional configuration merged with the current `Run` config.
        :param index_root: directory containing the index; falls back to the
            configuration's `index_root_`.
        :param verbose: verbosity level; values above 1 print memory stats first.
        :raises ValueError: if a memory-mapped index is requested while GPUs are visible.
        """
        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        # Resolve where the index lives and load the configuration stored with it.
        initial_config = ColBERTConfig.from_existing(config, Run().config)
        fallback_root = initial_config.index_root_
        resolved_root = index_root or fallback_root
        self.index = os.path.join(resolved_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        # Merge checkpoint, index and caller configuration, in that order.
        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        self.config = ColBERTConfig.from_existing(
            self.checkpoint_config,
            self.index_config,
            initial_config,
        )

        self.collection = Collection.cast(collection or self.config.collection)
        self.configure(checkpoint=self.checkpoint, collection=self.collection)

        # From here on self.checkpoint is the loaded model, no longer its path/name.
        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)

        # X-MOD models need a language selected before encoding; other
        # backbones are left untouched.
        backbone_name = self.checkpoint.bert.__class__.__name__.lower()
        if backbone_name.startswith("xmod"):
            # The language is inferred from the first item of the collection only.
            first_item = self.collection[0]
            language = detect(first_item)
            print_message(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)

        use_gpu = self.config.total_visible_gpus > 0
        if use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        load_index_with_mmap = self.config.load_index_with_mmap
        if use_gpu and load_index_with_mmap:
            raise ValueError(f"Memory-mapped index can only be used with CPU!")

        self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap)

        print_memory_stats()