antoinelouis committed on
Commit
075dfca
1 Parent(s): 566931f

Create custom.py

Browse files
Files changed (1) hide show
  1. custom.py +148 -0
custom.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langdetect import detect
3
+ import torch.multiprocessing as mp
4
+
5
+ from colbert import Indexer, Searcher
6
+ from colbert.infra import ColBERTConfig, Run
7
+ from colbert.utils.utils import print_message
8
+ from colbert.data.collection import Collection
9
+ from colbert.modeling.checkpoint import Checkpoint
10
+ from colbert.indexing.index_saver import IndexSaver
11
+ from colbert.search.index_storage import IndexScorer
12
+ from colbert.infra.launcher import Launcher, print_memory_stats
13
+ from colbert.indexing.collection_encoder import CollectionEncoder
14
+ from colbert.indexing.collection_indexer import CollectionIndexer
15
+
16
+
17
# Master table: two-letter language code -> (language name, X-MOD adapter code).
# The adapter code (second tuple item) is what gets passed to the model's
# ``set_default_language``.
_LANGUAGE_TABLE = {
    'ar': ('arabic', 'ar_AR'),
    'bn': ('bengali', 'bn_IN'),
    'de': ('german', 'de_DE'),
    'en': ('english', 'en_XX'),
    'es': ('spanish', 'es_XX'),
    'fa': ('persian', 'fa_IR'),
    'fi': ('finnish', 'fi_FI'),
    'fr': ('french', 'fr_XX'),
    'hi': ('hindi', 'hi_IN'),
    'id': ('indonesian', 'id_ID'),
    'it': ('italian', 'it_IT'),
    'ja': ('japanese', 'ja_XX'),
    'ko': ('korean', 'ko_KR'),
    'nl': ('dutch', 'nl_XX'),
    'pt': ('portuguese', 'pt_XX'),
    'ru': ('russian', 'ru_RU'),
    'sw': ('swahili', 'sw_KE'),
    'te': ('telugu', 'te_IN'),
    'th': ('thai', 'th_TH'),
    'vi': ('vietnamese', 'vi_VN'),
    'zh': ('chinese', 'zh_CN'),
}


def _select_languages(*codes):
    """Return the entries of ``_LANGUAGE_TABLE`` for ``codes``, in the order given."""
    return {code: _LANGUAGE_TABLE[code] for code in codes}


# Per-dataset language subsets (the names suggest the mMARCO, Mr. TyDi and
# MIRACL benchmarks -- confirm against the project documentation).
MMARCO_LANGUAGES = _select_languages(
    'ar', 'de', 'en', 'es', 'fr', 'hi', 'id', 'it', 'ja', 'nl', 'pt', 'ru', 'vi', 'zh',
)
MRTYDI_LANGUAGES = _select_languages(
    'ar', 'bn', 'en', 'fi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
)
MIRACL_LANGUAGES = _select_languages(
    'ar', 'bn', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th', 'zh',
)

# Union of the three subsets; this is the lookup table used to resolve a
# detected language code to its adapter code.
ALL_LANGUAGES = {**MMARCO_LANGUAGES, **MRTYDI_LANGUAGES, **MIRACL_LANGUAGES}
65
+
66
+
67
def set_xmod_language(model, lang:str):
    """
    Select the default X-MOD language adapter of ``model`` from a language code.

    Only the part of ``lang`` before the first '-' is looked up in ``ALL_LANGUAGES``
    (so a regional code such as 'zh-cn' resolves like 'zh'). The adapter code stored
    for that language is then passed to ``model.set_default_language``.
    Source: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/xmod/modeling_xmod.py#L687

    Raises:
        KeyError: if the (truncated) language code is not a key of ``ALL_LANGUAGES``.
    """
    lang = lang.split('-')[0]
    entry = ALL_LANGUAGES.get(lang)
    if entry is None:
        raise KeyError(f"Language {lang} not supported.")
    # entry is (language name, adapter code); the model expects the adapter code.
    model.set_default_language(entry[1])
77
+
78
+ #-----------------------------------------------------------------------------------------------------------------#
79
+ # INDEXER
80
+ #-----------------------------------------------------------------------------------------------------------------#
81
class CustomIndexer(Indexer):
    """
    ``Indexer`` variant whose launch step runs ``custom_encode`` in the workers.

    NOTE on the method name: Python mangles any ``__name`` identifier inside a
    class body to ``_<ClassName>__name``. A method written as ``__launch`` here
    would therefore be stored as ``_CustomIndexer__launch`` and could never
    override the parent's private ``_Indexer__launch``. ColBERT's ``Indexer``
    is expected to invoke its launch step as ``self.__launch(...)``, i.e.
    ``self._Indexer__launch(...)`` (confirm against the installed ColBERT
    version), so the override is defined under that exact mangled name.
    """

    def _Indexer__launch(self, collection):
        """
        Spawn the indexing workers for ``collection`` through a ``Launcher``.

        One managed list and one single-slot managed queue are created per rank
        (``self.config.nranks``) and handed to the launcher together with the
        config, the collection and the verbosity level.
        """
        manager = mp.Manager()
        nranks = self.config.nranks
        shared_lists = [manager.list() for _ in range(nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(nranks)]
        launcher = Launcher(custom_encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)

    # Backward compatibility: keep the attribute the original definition
    # produced (this assignment is mangled to ``_CustomIndexer__launch``).
    __launch = _Indexer__launch
88
+
89
def custom_encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    """
    Worker entry point: build a ``CustomCollectionIndexer`` and run it.

    The indexer is constructed from ``config``, ``collection`` and ``verbose``
    and its ``run`` method is called with ``shared_lists``. ``shared_queues``
    is accepted as part of the call signature but is not used here.
    """
    CustomCollectionIndexer(
        config=config,
        collection=collection,
        verbose=verbose,
    ).run(shared_lists)
92
+
93
class CustomCollectionIndexer(CollectionIndexer):
    """
    ``CollectionIndexer`` that configures X-MOD language adapters before encoding.

    The constructor does not call ``super().__init__``; it sets up the
    attributes itself so that the language adapter can be selected on the
    checkpoint before it is moved to the GPU and wrapped in the encoder.
    """

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        """
        Args:
            config: ColBERT configuration; ``rank``, ``nranks``,
                ``total_visible_gpus`` and ``checkpoint`` are read from it.
            collection: anything accepted by ``Collection.cast``.
            verbose: verbosity level; the config help is printed on rank 0
                when it is greater than 1.

        Raises:
            KeyError: if the detected language is not supported by
                ``set_xmod_language``.
        """
        self.verbose = verbose
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks
        self.use_gpu = self.config.total_visible_gpus > 0
        if self.config.rank == 0 and self.verbose > 1:
            self.config.help()
        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)
        if self.checkpoint.bert.__class__.__name__.lower().startswith("xmod"):
            # langdetect is non-deterministic unless seeded. Every rank runs this
            # constructor independently, so without a seed two ranks (or the
            # indexer and a later searcher) could pick different adapters for
            # the same collection. An already configured seed is left untouched.
            from langdetect import DetectorFactory
            if DetectorFactory.seed is None:
                DetectorFactory.seed = 0
            # The language is inferred from the first passage of the collection only.
            language = detect(self.collection[0])
            Run().print_main(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)
        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()
        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)
        print_memory_stats(f'RANK:{self.rank}')
112
+
113
+ #-----------------------------------------------------------------------------------------------------------------#
114
+ # SEARCHER
115
+ #-----------------------------------------------------------------------------------------------------------------#
116
class CustomSearcher(Searcher):
    """
    ``Searcher`` that configures X-MOD language adapters on the loaded checkpoint.

    The constructor does not call ``super().__init__``; it resolves the
    configuration, loads the checkpoint and the index itself so that the
    language adapter can be selected before the checkpoint is moved to the GPU.
    """

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose:int = 3):
        """
        Args:
            index: index name, joined to the index root to locate the index.
            checkpoint: checkpoint to load; defaults to the one recorded in the
                index configuration.
            collection: anything accepted by ``Collection.cast``; defaults to
                the collection recorded in the merged configuration.
            config: optional configuration merged with the current ``Run`` config.
            index_root: directory containing the index; defaults to the
                ``index_root_`` of the initial configuration.
            verbose: verbosity level; memory stats are printed up front when
                it is greater than 1.

        Raises:
            ValueError: if a memory-mapped index is requested while GPUs are visible.
            KeyError: if the detected language is not supported by
                ``set_xmod_language``.
        """
        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        initial_config = ColBERTConfig.from_existing(config, Run().config)

        # An explicit (truthy) index_root wins over the configured default.
        index_root = index_root or initial_config.index_root_
        self.index = os.path.join(index_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config)

        self.collection = Collection.cast(collection or self.config.collection)
        self.configure(checkpoint=self.checkpoint, collection=self.collection)

        # From here on self.checkpoint is the loaded model, no longer the name/path.
        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)
        if self.checkpoint.bert.__class__.__name__.lower().startswith("xmod"):
            # langdetect is non-deterministic unless seeded; seed it so that the
            # adapter chosen here is reproducible across runs for the same
            # collection. An already configured seed is left untouched.
            from langdetect import DetectorFactory
            if DetectorFactory.seed is None:
                DetectorFactory.seed = 0
            # The language is inferred from the first passage of the collection only.
            language = detect(self.collection[0])
            print_message(f"#> Setting X-MOD language adapters to {language}.")
            set_xmod_language(self.checkpoint.bert, lang=language)
        use_gpu = self.config.total_visible_gpus > 0
        if use_gpu:
            self.checkpoint = self.checkpoint.cuda()
        load_index_with_mmap = self.config.load_index_with_mmap
        if load_index_with_mmap and use_gpu:
            raise ValueError("Memory-mapped index can only be used with CPU!")
        self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap)
        print_memory_stats()