|
|
|
|
|
from typing import Dict, Any |
|
|
|
from peft import PeftModel |
|
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig |
|
import transformers |
|
|
|
|
|
import torch |
|
from torch import Tensor |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from transformers.models.mistral.modeling_mistral import MistralAttention |
|
from ExtractableMistralAttention import forward |
|
|
|
# Monkey-patch Mistral's attention so EndpointHandler can read
# `self_attn.attn_vec` from the last decoder layer. The replacement forward comes
# from the local ExtractableMistralAttention module; presumably it stores that
# vector on the attention module during each forward call -- confirm there.
setattr(MistralAttention, "forward", forward)
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that embeds a single text as a query or a document.

    Every request yields an L2-normalised last-token embedding taken from the
    model's final hidden state. Query requests additionally return an
    L2-normalised last-token vector read from the final decoder layer's
    ``self_attn.attn_vec`` (presumably populated by the patched
    ``MistralAttention.forward`` installed at import time -- confirm).
    """

    # Request types accepted by __call__; anything else is rejected before
    # any model work is done.
    _REQUEST_TYPES = ('query', 'document')

    def __init__(self, model_dir=''):
        """Load the tokenizer and the 8-bit quantised model from ``model_dir``.

        Args:
            model_dir: Path or hub id passed to ``from_pretrained``.
        """
        # Prepended to query texts only (see tokenize); documents are embedded as-is.
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        # Token budget per request, including the trailing EOS used for pooling.
        self.max_length = 4096
        # Inputs are moved here. With device_map="auto" this assumes the input
        # embedding layer is placed on cuda:0 when a GPU exists -- TODO confirm.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        # NOTE(review): '[PAD]' may not exist in the vocabulary. Harmless today
        # because each request is a single unpadded sequence; revisit before batching.
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'

        # NOTE(review): BitsAndBytesConfig documents bnb_4bit_compute_dtype but
        # not bnb_8bit_compute_dtype; this kwarg is likely ignored -- confirm.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)

        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            # Eager attention, presumably so the patched MistralAttention.forward
            # is the implementation that actually runs -- confirm per transformers version.
            attn_implementation="eager",
        )
        self.model.eval()

    def last_token_pool(self, last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Return, for each row, the state at its last non-padding position.

        Args:
            last_hidden_states: Per-token states indexed as (batch, seq, ...).
            attention_mask: (batch, seq) mask, 1 for real tokens.

        Returns:
            Tensor indexed as (batch, ...).
        """
        # If every row's final position is a real token, all rows end at -1
        # (left padding, or no padding at all).
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Right padding: each row's last real token sits at (count of real tokens - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, request_type):
        """Tokenize one text (query texts get the instruction prefix) and move it to ``self.device``.

        The sequence always ends with the EOS token, because ``last_token_pool``
        pools the final position. When the input exceeds ``max_length``,
        truncation would otherwise cut the appended EOS off.

        Args:
            text: The string to embed.
            request_type: ``'query'`` adds ``self.instruction`` in front of ``text``.

        Returns:
            The tokenizer's encoding (``input_ids``/``attention_mask``) for a batch of one.
        """
        if request_type == 'query':
            text = self.instruction + text
        encoded = self.tokenizer(
            text + self.tokenizer.eos_token,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids']
        eos_id = self.tokenizer.eos_token_id
        # Truncation dropped the trailing EOS: overwrite the last kept token so the
        # result equals "truncate to max_length - 1, then append EOS".
        if eos_id is not None and input_ids.shape[-1] > 0 and input_ids[0, -1].item() != eos_id:
            input_ids[0, -1] = eos_id
        return encoded.to(self.device)

    def extract_attn_vec(self, model):
        """Return the ``attn_vec`` stored on ``model``'s final decoder layer attention.

        Args:
            model: Model exposing ``layers[-1].self_attn`` (normally ``self.model``).
        """
        # Reads from the argument (not self.model) so any compatible model can be passed.
        return model.layers[-1].self_attn.attn_vec

    def embed(self, text, request_type):
        """Embed ``text`` and return ``(embedding, attn_vec)``.

        Both tensors are last-token pooled and L2-normalised along dim 1, and
        keep the leading batch dimension of size 1.
        """
        tokens = self.tokenize(text, request_type)
        attention_mask = tokens['attention_mask']
        with torch.no_grad():
            output = self.model(tokens['input_ids'], attention_mask).last_hidden_state.detach()
            embedding = F.normalize(self.last_token_pool(output, attention_mask), p=2, dim=1)

            # Read after the forward pass above, which (presumably via the patched
            # attention) leaves the final layer's vector on the module.
            attn_vec = self.last_token_pool(self.extract_attn_vec(self.model), attention_mask)
            attn_vec = F.normalize(attn_vec, p=2, dim=1)

        del output, tokens, attention_mask
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one embedding request.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                present, otherwise from ``data`` itself: ``id`` (echoed back),
                ``text`` (the string to embed) and ``type`` (``'query'`` or
                ``'document'``). The payload is not modified.

        Returns:
            ``{"id", "embedding"}`` for documents; queries additionally include
            ``"attention_vec"``. Vectors are plain lists of floats.

        Raises:
            ValueError: if the inputs are not a dict, ``type`` is unsupported or
                missing, or ``text`` is not a string.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("request inputs must be a JSON object with 'id', 'text' and 'type'")

        request_id = inputs.get("id")
        text = inputs.get("text")
        request_type = inputs.get("type")

        # Validate before touching the model so bad requests cost no GPU time.
        if request_type not in self._REQUEST_TYPES:
            raise ValueError(f"unsupported request type {request_type!r}; expected one of {self._REQUEST_TYPES}")
        if not isinstance(text, str):
            raise ValueError("request field 'text' must be a string")

        embeddings, attn_vec = self.embed(text, request_type)

        response = {"id": request_id, "embedding": embeddings[0].tolist()}
        if request_type == 'query':
            response["attention_vec"] = attn_vec[0].tolist()
        return response