#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from pyserini.encode import DocumentEncoder, QueryEncoder


class DprDocumentEncoder(DocumentEncoder):
    """Encode passages into dense vectors with a DPR context encoder."""

    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        """Load the context encoder and its tokenizer.

        Args:
            model_name: Name or path of the pretrained ``DPRContextEncoder``.
            tokenizer_name: Name or path of the tokenizer; defaults to
                ``model_name`` when not given.
            device: Torch device the model and inputs are moved to.
        """
        self.device = device
        self.model = DPRContextEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, max_length=256, **kwargs):
        """Encode a batch of passages.

        Args:
            texts: Passage texts to encode.
            titles: Optional passage titles. When non-empty, each title is
                encoded as the first segment and its text as the second.
            max_length: Maximum number of tokens per encoded sequence; longer
                inputs are truncated.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A numpy array holding the model's ``pooler_output``, one row per
            input passage.
        """
        # With titles, the title is the first segment and the passage text the
        # second; without them the text is encoded alone (text_pair=None is
        # the tokenizer's default, i.e. a single-segment input).
        if titles:
            first, pair = titles, texts
        else:
            first, pair = texts, None

        inputs = self.tokenizer(
            first,
            text_pair=pair,
            max_length=max_length,
            padding='longest',
            truncation=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        inputs.to(self.device)

        # Inference only: skip building the autograd graph. The attention mask
        # must be passed because the batch is padded to its longest member;
        # without it, pad tokens are attended to and a passage's embedding
        # changes depending on what else is in its batch.
        with torch.no_grad():
            outputs = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )
        return outputs.pooler_output.cpu().numpy()


class DprQueryEncoder(QueryEncoder):
    """Encode a single query into a dense vector with a DPR question encoder."""

    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        """Load the question encoder and its tokenizer.

        Args:
            model_name: Name or path of the pretrained ``DPRQuestionEncoder``.
            tokenizer_name: Name or path of the tokenizer; defaults to
                ``model_name`` when not given.
            device: Torch device the model and inputs are moved to.
        """
        self.device = device
        self.model = DPRQuestionEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, query: str, **kwargs):
        """Encode one query string.

        Args:
            query: The query text.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The model's ``pooler_output`` for the query, flattened to a 1-D
            numpy array.
        """
        inputs = self.tokenizer(query, return_tensors='pt')
        inputs.to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )
        embeddings = outputs.pooler_output.cpu().numpy()
        return embeddings.flatten()