# Custom EndpointHandler for Hugging Face Inference Endpoints
from typing import Dict, Any
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import transformers
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
from transformers.models.mistral.modeling_mistral import MistralAttention
from ExtractableMistralAttention import forward
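
# Patch MistralAttention so each forward pass caches its attention output
# (attn_vec) on the module; extract_attn_vec() reads it back after embedding.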
MistralAttention.forward = forward
class EndpointHandler():
def __init__(self, model_dir=''):
self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
self.max_length = 4096
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
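        # Left padding with a dedicated [PAD] token keeps the last position of
        # every sequence a real token, which last_token_pool() relies on.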
self.tokenizer.pad_token = '[PAD]'
self.tokenizer.padding_side = 'left'
        # Quantize the base model to 8-bit with bitsandbytes so it fits in less GPU memory
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
self.model = AutoModel.from_pretrained(
model_dir,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
attn_implementation="eager",
)
#self.model = PeftModel.from_pretrained(self.model, model_dir, subfolder='lora')
self.model.eval()
def last_token_pool(self, last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
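        # Pool the hidden state of the last non-padding token of each sequence
        # (with left padding this is simply the final position).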
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def tokenize(self, text, request_type):
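        # Queries get the retrieval instruction prefix; an EOS token is appended
        # so last-token pooling lands on the end-of-sequence representation.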
if request_type == 'query':
text = self.instruction + text
return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
def extract_attn_vec(self, model):
        # Read the attention output cached by the patched MistralAttention.forward
        # on the last decoder layer.
        return model._modules['layers'][-1].self_attn.attn_vec
def embed(self, text, request_type):
tokens = self.tokenize(text, request_type)
with torch.no_grad():
output = self.model(tokens['input_ids'], tokens['attention_mask']).last_hidden_state.detach()
embedding = self.last_token_pool(output, tokens['attention_mask'])
embedding = F.normalize(embedding, p=2, dim=1)
attn_vec = self.extract_attn_vec(self.model)
attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
attn_vec = F.normalize(attn_vec, p=2, dim=1)
del output, tokens
torch.cuda.empty_cache()
return embedding, attn_vec
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
data args:
inputs (:obj: `str` | `PIL.Image` | `np.array`)
kwargs
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
inputs = data.pop("inputs", data)
id = inputs.pop("id", inputs)
text = inputs.pop("text", inputs)
request_type = inputs.pop("type", inputs)
embeddings, attn_vec = self.embed(text, request_type)
embeddings = embeddings[0].tolist()
attn_vec = attn_vec[0].tolist()
if request_type == 'query':
return {"id": id, "embedding": embeddings, "attention_vec": attn_vec}
elif request_type == 'document':
return {"id": id, "embedding": embeddings} |