import os
from langdetect import detect
import torch.multiprocessing as mp

from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run
from colbert.utils.utils import print_message
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.index_saver import IndexSaver
from colbert.search.index_storage import IndexScorer
from colbert.infra.launcher import Launcher, print_memory_stats
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer


# Each table maps an ISO-639-1 code to (language name, X-MOD adapter code).
MMARCO_LANGUAGES = {
    'ar': ('arabic', 'ar_AR'),
    'de': ('german', 'de_DE'),
    'en': ('english', 'en_XX'),
    'es': ('spanish', 'es_XX'),
    'fr': ('french', 'fr_XX'),
    'hi': ('hindi', 'hi_IN'),
    'id': ('indonesian', 'id_ID'),
    'it': ('italian', 'it_IT'),
    'ja': ('japanese', 'ja_XX'),
    'nl': ('dutch', 'nl_XX'),
    'pt': ('portuguese', 'pt_XX'),
    'ru': ('russian', 'ru_RU'),
    'vi': ('vietnamese', 'vi_VN'),
    'zh': ('chinese', 'zh_CN'),
}
MRTYDI_LANGUAGES = {
    'ar': ('arabic', 'ar_AR'),
    'bn': ('bengali', 'bn_IN'),
    'en': ('english', 'en_XX'),
    'fi': ('finnish', 'fi_FI'),
    'id': ('indonesian', 'id_ID'),
    'ja': ('japanese', 'ja_XX'),
    'ko': ('korean', 'ko_KR'),
    'ru': ('russian', 'ru_RU'),
    'sw': ('swahili', 'sw_KE'),
    'te': ('telugu', 'te_IN'),
    'th': ('thai', 'th_TH'),
}
MIRACL_LANGUAGES = {
    'ar': ('arabic', 'ar_AR'),
    'bn': ('bengali', 'bn_IN'),
    'en': ('english', 'en_XX'),
    'es': ('spanish', 'es_XX'),
    'fa': ('persian', 'fa_IR'),
    'fi': ('finnish', 'fi_FI'),
    'fr': ('french', 'fr_XX'),
    'hi': ('hindi', 'hi_IN'),
    'id': ('indonesian', 'id_ID'),
    'ja': ('japanese', 'ja_XX'),
    'ko': ('korean', 'ko_KR'),
    'ru': ('russian', 'ru_RU'),
    'sw': ('swahili', 'sw_KE'),
    'te': ('telugu', 'te_IN'),
    'th': ('thai', 'th_TH'),
    'zh': ('chinese', 'zh_CN'),
}
# Union of all benchmark tables (entries shared between tables are identical).
ALL_LANGUAGES = {**MMARCO_LANGUAGES, **MRTYDI_LANGUAGES, **MIRACL_LANGUAGES}


def set_xmod_language(model, lang: str):
    """
    Set the default language code for the model.
    This is used when the language is not specified in the input.
    Source: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/xmod/modeling_xmod.py#L687

    Args:
        model: An X-MOD model exposing ``set_default_language``.
        lang: Language code; only the part before the first '-' is used, so region-suffixed
            codes such as 'zh-cn' resolve to 'zh'.

    Raises:
        KeyError: If the base language code is not in ``ALL_LANGUAGES``.
    """
    lang = lang.split('-')[0]
    if (value := ALL_LANGUAGES.get(lang)) is not None:
        model.set_default_language(value[1])
    else:
        raise KeyError(f"Language {lang} not supported.")


def _maybe_set_xmod_language(checkpoint, collection, log):
    """Activate the X-MOD language adapter on ``checkpoint`` if its encoder is an X-MOD model.

    The language is detected from the first passage of ``collection`` only, and the
    message is emitted through ``log``. Non-X-MOD checkpoints are left untouched.
    """
    if not checkpoint.bert.__class__.__name__.lower().startswith("xmod"):
        return
    # NOTE(review): langdetect is non-deterministic unless langdetect.DetectorFactory.seed
    # is set, so indexing and searching could in principle pick different adapters.
    language = detect(collection[0])
    log(f"#> Setting X-MOD language adapters to {language}.")
    set_xmod_language(checkpoint.bert, lang=language)


#-----------------------------------------------------------------------------------------------------------------#
# INDEXER
#-----------------------------------------------------------------------------------------------------------------#
class CustomIndexer(Indexer):
    """ColBERT ``Indexer`` whose indexing workers run :func:`custom_encode`.

    Upstream ColBERT's ``Indexer.index()`` launches its workers via ``self.__launch(...)``
    (verify against the pinned ColBERT version). Inside ``Indexer`` that call is
    name-mangled to ``self._Indexer__launch``. A method merely named ``__launch`` here
    is mangled to ``_CustomIndexer__launch`` instead, so it would never be called.
    We therefore override the mangled parent name directly.
    """

    def _Indexer__launch(self, collection):
        """Launch ``self.config.nranks`` workers that run ``custom_encode`` over ``collection``."""
        manager = mp.Manager()
        # One shared list and one single-slot queue per rank, handed to every worker.
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(custom_encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)

    # Backward-compatible alias for the originally defined (but never invoked) private name.
    __launch = _Indexer__launch


def custom_encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    """Worker entry point: encode ``collection`` with a :class:`CustomCollectionIndexer`.

    ``shared_queues`` is accepted to match the launcher's argument list but is not used here.
    """
    encoder = CustomCollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)


class CustomCollectionIndexer(CollectionIndexer):
    """``CollectionIndexer`` that activates X-MOD language adapters before encoding."""

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        # The parent initializer is intentionally not called: this re-creates its setup so
        # the language adapter can be set on the model before it is moved to the GPU.
        self.verbose = verbose
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks

        self.use_gpu = self.config.total_visible_gpus > 0

        if self.config.rank == 0 and self.verbose > 1:
            self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)
        _maybe_set_xmod_language(self.checkpoint, self.collection, Run().print_main)

        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)

        print_memory_stats(f'RANK:{self.rank}')


#-----------------------------------------------------------------------------------------------------------------#
# SEARCHER
#-----------------------------------------------------------------------------------------------------------------#
class CustomSearcher(Searcher):
    """ColBERT ``Searcher`` that activates X-MOD language adapters before searching."""

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose: int = 3):
        # The parent initializer is intentionally not called: this re-creates its setup so
        # the language adapter can be set on the model before it is moved to the GPU.
        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        initial_config = ColBERTConfig.from_existing(config, Run().config)

        default_index_root = initial_config.index_root_
        index_root = index_root if index_root else default_index_root
        self.index = os.path.join(index_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        # An explicit checkpoint wins; otherwise fall back to the one recorded in the index.
        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config)

        self.collection = Collection.cast(collection or self.config.collection)
        self.configure(checkpoint=self.checkpoint, collection=self.collection)

        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)
        _maybe_set_xmod_language(self.checkpoint, self.collection, print_message)

        use_gpu = self.config.total_visible_gpus > 0
        if use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        load_index_with_mmap = self.config.load_index_with_mmap
        if load_index_with_mmap and use_gpu:
            raise ValueError("Memory-mapped index can only be used with CPU!")
        self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap)

        print_memory_stats()