# Custom handler for Hugging Face Inference Endpoints: embeds queries and
# documents with a Mistral-based encoder and, for queries, also returns the
# last layer's attention vector.

from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch import Tensor

from peft import PeftModel  # only needed if the optional LoRA loading below is re-enabled
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

from transformers.models.mistral.modeling_mistral import MistralAttention
from ExtractableMistralAttention import forward

# Monkey-patch every MistralAttention layer with a forward that additionally
# caches its attention output so it can be read back after a forward pass.
MistralAttention.forward = forward
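
# Rough sketch of what the patched forward is assumed to do (the actual
# implementation ships in ExtractableMistralAttention.py alongside this file):
# run the standard Mistral attention computation, then stash the result on the
# module before returning, e.g.:
#
#     def forward(self, hidden_states, attention_mask=None, **kwargs):
#         attn_output = ...  # usual Mistral attention math
#         self.attn_vec = attn_output  # extra line: cache for later extraction
#         return attn_output, attn_weights, past_key_value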

class EndpointHandler:
    def __init__(self, model_dir=''):
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        self.max_length = 4096
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'  # left padding keeps the last position a real token

        # BitsAndBytesConfig has no 8-bit compute-dtype option (bnb_4bit_compute_dtype
        # applies only to 4-bit quantization), so load_in_8bit=True is sufficient here.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            # eager attention ensures the patched MistralAttention.forward is
            # actually used (sdpa / flash-attention variants would bypass it)
            attn_implementation="eager",
        )
        # Optionally load a LoRA adapter from the `lora/` subfolder:
        # self.model = PeftModel.from_pretrained(self.model, model_dir, subfolder='lora')
        self.model.eval()


    def last_token_pool(self, last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        # With left padding, the last column of the mask is all ones, so the
        # final position holds the true last token of every sequence.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            # Right padding: index each row at its own last non-pad position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
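
    # Example: for a left-padded batch
    #     attention_mask = [[0, 1, 1],
    #                       [1, 1, 1]]
    # the last column is all ones, so last_token_pool simply returns
    # last_hidden_states[:, -1] for every row.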


    def tokenize(self, text, request_type):
        # Queries are prefixed with the retrieval instruction; documents are not.
        if request_type == 'query':
            text = self.instruction + text
        # Append EOS so last-token pooling always lands on the same final token.
        return self.tokenizer(text + self.tokenizer.eos_token, max_length=self.max_length, truncation=True, return_tensors='pt').to(self.device)
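
    # For request_type == 'query' with a Mistral tokenizer (eos token '</s>'),
    # the encoded string is assumed to look like:
    #     'Given a web search query, retrieve relevant passages that answer the query:\nwhat is RAG?</s>'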


    def extract_attn_vec(self, model):
        # Read back the attention vector cached by the patched MistralAttention
        # forward in the last decoder layer.
        return model._modules['layers'][-1].self_attn.attn_vec


    def embed(self, text, request_type):
        tokens = self.tokenize(text, request_type)
        with torch.no_grad():
            output = self.model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask']).last_hidden_state.detach()
            embedding = self.last_token_pool(output, tokens['attention_mask'])
            embedding = F.normalize(embedding, p=2, dim=1)

            # The attention vector was cached by the patched forward during the
            # same pass; pool and normalize it the same way as the embedding.
            attn_vec = self.extract_attn_vec(self.model)
            attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
            attn_vec = F.normalize(attn_vec, p=2, dim=1)
            del output, tokens
            torch.cuda.empty_cache()
            return embedding, attn_vec
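
    # Both tensors come back with shape (1, hidden_size) for a single input and
    # are L2-normalized, so a plain dot product equals cosine similarity.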


    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): with keys `id`, `text`, and
                `type` ('query' or 'document')
        Return:
            A :obj:`dict` that will be serialized and returned as JSON
        """
        inputs = data.pop("inputs", data)
        request_id = inputs.pop("id", None)
        text = inputs.pop("text", None)
        request_type = inputs.pop("type", None)

        embeddings, attn_vec = self.embed(text, request_type)
        embeddings = embeddings[0].tolist()
        attn_vec = attn_vec[0].tolist()

        if request_type == 'query':
            return {"id": request_id, "embedding": embeddings, "attention_vec": attn_vec}

        elif request_type == 'document':
            return {"id": request_id, "embedding": embeddings}

        raise ValueError(f"Unknown request type: {request_type!r} (expected 'query' or 'document')")