NetsPresso_QA / pyserini /encode /_tct_colbert.py
geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
3.77 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import torch
if torch.cuda.is_available():
from torch.cuda.amp import autocast
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from pyserini.encode import DocumentEncoder, QueryEncoder
from onnxruntime import ExecutionMode, SessionOptions, InferenceSession
class TctColBertDocumentEncoder(DocumentEncoder):
    """Encode documents into dense vectors with a TCT-ColBERT BERT model.

    Inference runs through a PyTorch ``BertModel``, or through an ONNX
    Runtime ``InferenceSession`` when ``model_name`` ends with ``onnx``.
    """

    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
        """Load the model (PyTorch or ONNX) and its tokenizer.

        Args:
            model_name: Name or path handed to ``BertModel.from_pretrained``,
                or a path ending in ``onnx`` to load with ONNX Runtime.
            tokenizer_name: Name or path of the tokenizer. When falsy, it is
                derived from ``model_name``.
            device: Torch device that the model and inputs are moved to.
        """
        self.device = device
        self.onnx = False
        if model_name.endswith('onnx'):
            options = SessionOptions()
            self.session = InferenceSession(model_name, options)
            self.onnx = True
            # Default tokenizer name is the model path minus its last five
            # characters, i.e. the '.onnx' extension.
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name[:-5])
        else:
            self.model = BertModel.from_pretrained(model_name)
            self.model.to(self.device)
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
        """Encode a batch of documents into pooled embeddings.

        Args:
            texts: Iterable of document body strings.
            titles: Optional iterable of titles paired with ``texts`` via
                ``zip``; each title is placed in front of its text.
            fp16: If true, run the PyTorch model under CUDA autocast.
                Not consulted on the ONNX path.
            max_length: Maximum tokenized length; longer inputs are truncated.
            **kwargs: Accepted and ignored.

        Returns:
            The result of ``self._mean_pooling`` as a NumPy array on the CPU.
        """
        # The '[CLS] [D] ' marker is written into the text itself, which is
        # why the tokenizer is called with add_special_tokens=False.
        if titles is not None:
            texts = [f'[CLS] [D] {title} {text}' for title, text in zip(titles, texts)]
        else:
            texts = ['[CLS] [D] ' + text for text in texts]
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding="longest",
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        # NOTE(review): `_mean_pooling` is not defined in this class; presumably
        # inherited from DocumentEncoder. The `4:` slices skip the first four
        # token positions, presumably the tokens of the '[CLS] [D] ' prefix.
        if self.onnx:
            # Build the NumPy feed before moving tensors to self.device:
            # NumPy can only wrap tensors that live on the CPU.
            inputs_onnx = {name: np.atleast_2d(value) for name, value in inputs.items()}
            inputs.to(self.device)
            # The session is expected to yield exactly two outputs; only the
            # first one is used.
            outputs, _ = self.session.run(None, inputs_onnx)
            outputs = torch.from_numpy(outputs).to(self.device)
            embeddings = self._mean_pooling(outputs[:, 4:, :], inputs['attention_mask'][:, 4:])
        else:
            inputs.to(self.device)
            # Inference only: never build an autograd graph, regardless of fp16.
            with torch.no_grad():
                if fp16:
                    # Imported here because the module-level import is guarded
                    # by torch.cuda.is_available(), so the name may be missing.
                    from torch.cuda.amp import autocast
                    with autocast():
                        outputs = self.model(**inputs)
                else:
                    outputs = self.model(**inputs)
            embeddings = self._mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
        return embeddings.detach().cpu().numpy()
class TctColBertQueryEncoder(QueryEncoder):
    """Encode a single query into a dense vector with a TCT-ColBERT BERT model."""

    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        """Load the BERT model and tokenizer.

        Args:
            model_name: Name or path handed to ``BertModel.from_pretrained``.
            tokenizer_name: Name or path of the tokenizer. When falsy,
                ``model_name`` is used instead.
            device: Torch device that the model and inputs are moved to.
        """
        self.device = device
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, query: str, **kwargs):
        """Encode one query string into a flat embedding vector.

        Args:
            query: The query text.
            **kwargs: Accepted and ignored.

        Returns:
            A 1-D NumPy array: the last hidden states averaged over the token
            axis, excluding the first four positions.
        """
        max_length = 36  # hardcode for now
        # The '[CLS] [Q] ' marker is written into the text itself, which is why
        # the tokenizer is called with add_special_tokens=False. '[MASK]' is
        # appended max_length times and the input is then truncated to
        # max_length, presumably so every query fills the whole window.
        inputs = self.tokenizer(
            '[CLS] [Q] ' + query + '[MASK]' * max_length,
            max_length=max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        inputs.to(self.device)
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state.detach().cpu().numpy()
        # NOTE(review): the `4:` slice drops the first four token positions,
        # presumably the tokens of the '[CLS] [Q] ' prefix.
        return np.average(embeddings[:, 4:, :], axis=-2).flatten()