from .configuration_keeper import KeeperConfig

import torch
import numpy as np
from typing import Dict

from einops import rearrange
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

class KeeperModelForCausalLM(PreTrainedModel):
    """
    ColBERT model from: https://arxiv.org/pdf/2004.12832.pdf
    We use a dot-product instead of cosine per term (slightly better)
    """
    config_class = KeeperConfig
    base_model_prefix = "keeper_model"

    def __init__(self, cfg, n_cands=8, update_both=False) -> None:
        super().__init__(cfg)

        self.bert = None
        self.llm = None

        # if cfg:
        #     print("Initializing KeeperModelForCausalLM from cfg")
        #     # Initialization from the configuration

        #     self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])

        #     bnb_config = BitsAndBytesConfig(
        #         load_in_4bit=True,
        #         bnb_4bit_quant_type="nf4",
        #         bnb_4bit_compute_dtype=torch.bfloat16
        #     )

        #     self.llm = AutoModelForCausalLM.from_pretrained(
        #         cfg.model_config['_name_or_path'],
        #         device_map=cfg.device_map,
        #         torch_dtype=torch.bfloat16,
        #         quantization_config=bnb_config
        #     )

        #     # Store kwargs for future serialization and loading
        #     # self.init_kwargs = {'cfg': cfg}

        #     print("Initialization complete")
        # else:
        #     # If cfg is not provided, this is handled in the from_pretrained method
        #     print("Initializing KeeperTokenizer without cfg")

        self.n_cands = n_cands
        self.update_both = update_both
        print(f"Model n_cands: {self.n_cands}")


    def _load_from_state_dict(self, state_dict, *args, **kwargs):
        super()._load_from_state_dict(state_dict, *args, **kwargs)
        # Move the pre-tokenized document/prompt buffers (and the sub-models,
        # if serialized) onto the GPU when one is available.
        if torch.cuda.is_available():
            device = torch.device('cuda')
            for name in (
                "document_retriever_text",
                "document_retriever_mask",
                "document_retriever_type",
                "document_model_text",
                "prompt_left",
                "prompt_right",
                "respuesta",
                "bert",
                "llm",
            ):
                if name in state_dict:
                    setattr(self, name, state_dict[name].to(device))
        else:
            # Optionally handle the case where CUDA is not available
            print("CUDA is not available. Tensors will remain on CPU.")


    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, max_new_tokens=256, repetition_penalty=1.15, temperature=0.1, do_sample=True, **kwargs):

        # Query tokens for the generator LLM
        query_model = {key: v.to("cuda") for key, v in query['tokens_model'].items()}

        # Retrieve the top-k stored documents (already tokenized for the generator)
        topk_texts = self.document_extractor(query, k)
        concatenated_texts = torch.cat(topk_texts, dim=0)

        # Assemble the prompt: left template + retrieved context + right template + question + answer header
        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)

        prompt_length = T.shape[1]

        outputs = self.llm.generate(input_ids=T, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, do_sample=do_sample)

        # Return only the newly generated tokens (the prompt is stripped)
        return outputs[0][prompt_length:].unsqueeze(0)

    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):

        # Query tokens for the retriever encoder
        query_retriever = {key: v.to("cuda") for key, v in query['tokens_retriever'].items()}

        query_vecs = self.forward_representation(query_retriever)

        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")

        # Late-interaction relevance score for every stored document
        self.score = self.forward_aggregation(query_vecs, query_retriever["attention_mask"], document_vecs, self.document_retriever_mask)

        k_val = min(k_val, self.score.numel())
        topk_scores, topk_indices = torch.topk(self.score, k_val)

        # Assumes a single query (batch size 1); returns the generator token ids of the top documents
        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]

    def forward_representation(self,
                               tokens,
                               max_seq_len=128,
                               sequence_type=None) -> torch.Tensor:

        # Encode tokens with the retriever encoder (assuming a distilbert-style model here).
        # The document encoder participates in training only when update_both is set;
        # all other encoding runs under no_grad.
        if sequence_type == "doc" and self.update_both:
            vecs = self.bert(**tokens)[0]
        else:
            with torch.no_grad():
                vecs = self.bert(**tokens)[0]
        # vecs = self.compressor(vecs)
        return vecs

    def forward_aggregation(self, query_vecs, query_mask, document_vecs, document_mask):
        # ColBERT-style late interaction (MaxSim):
        #   score(q, d) = sum over query tokens of the max over doc tokens of (q_i . d_j)
        # query_vecs: B x N x D
        # doc_vecs: (B * k) x N x D

        # Duplicate each query once per candidate document
        _bsz = query_vecs.shape[0]
        n_cands = document_vecs.shape[0] // _bsz
        query_vecs_dup = query_vecs.repeat_interleave(n_cands, dim=0).contiguous()

        # Token-level dot products, with padded document tokens masked out
        score = torch.bmm(query_vecs_dup, document_vecs.transpose(1, 2))
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1, score.shape[1], -1)
        score[~exp_mask] = -10000

        # Max over document tokens, then sum over the non-padded query tokens
        score = score.max(-1).values
        query_mask_dup = query_mask.repeat_interleave(n_cands, dim=0).contiguous()

        score[~(query_mask_dup.bool())] = 0
        score = rearrange(score.sum(-1), '(b n) -> b n', n=n_cands)  # B x k
        return score
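    # Worked shape example for the aggregation above (illustrative sizes, not from the
    # repository): with one query of 32 tokens, 8 stored documents of 128 tokens each and
    # a 768-dim encoder, query_vecs is 1 x 32 x 768 and document_vecs is 8 x 128 x 768;
    # the bmm yields 8 x 32 x 128 token-level scores, the max over document tokens gives
    # 8 x 32, and the masked sum plus rearrange produces a 1 x 8 matrix with one
    # relevance score per candidate document.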

    def prompt(self, left_p = None, right_p = None):
        if left_p is None:
          left_p = """ <bos><start_of_turn>user 
          Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa. 
          Segun el siguiente contexto:
-------------------------------
"""
        if right_p is None:
          right_p = """
-------------------------------
- Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
- Si encuentras la respuesta puedes copiarla.
- Debes responder solamente en Espanol.
Pregunta: """
        return left_p, right_p

    def save_docs(self, docs: list, tokenizer, max_seq_len=128):
        # Tokenize the prompt templates
        prompt_left, prompt_right = self.prompt()
        prompt_left_output = tokenizer.encode(prompt_left)
        prompt_right_output = tokenizer.encode(prompt_right)

        # Tokenize the documents
        doc_outputs = tokenizer.encode(docs, max_length=max_seq_len, padding='max_length', truncation=True)

        # Move the tensors to CUDA  (## TODO optimize: tensors that will not be used on the GPU are stored there as well)
        doc_outputs = {k: v.to("cuda") for k, v in doc_outputs.items()}
        prompt_left_output = {k: v.to("cuda") for k, v in prompt_left_output.items()}
        prompt_right_output = {k: v.to("cuda") for k, v in prompt_right_output.items()}

        # Tokenize the answer header
        resp = tokenizer.encode("""
Respuesta: <end_of_turn>
        <start_of_turn>model """)
        resp_model = {k: v.to("cuda") for k, v in resp['tokens_model'].items()}

        # Update the buffers with the tokenized documents and prompts
        self.document_retriever_text = doc_outputs['tokens_retriever']['input_ids']
        self.document_retriever_mask = doc_outputs['tokens_retriever']['attention_mask']
        self.document_retriever_type = doc_outputs['tokens_retriever']['token_type_ids']
        self.document_model_text = doc_outputs['tokens_model']['input_ids']
        # self.document_model_mask = key_outputs['tokens_model']['attention_mask']
        # self.document_model_type = key_outputs['tokens_model']['token_type_ids']
        self.prompt_left = prompt_left_output['tokens_model']['input_ids']
        self.prompt_right = prompt_right_output['tokens_model']['input_ids']
        self.respuesta = resp_model['input_ids']
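

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept commented out). It assumes a
# companion KeeperTokenizer whose encode() returns a dict with
# 'tokens_retriever' and 'tokens_model' entries (the structure expected by
# save_docs()/generate() above); the tokenizer import path, model names and
# checkpoint paths below are placeholders, not part of this file.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     # from .tokenization_keeper import KeeperTokenizer   # hypothetical module name
#
#     model = KeeperModelForCausalLM.from_pretrained("path/to/keeper-checkpoint")   # placeholder path
#     tokenizer = KeeperTokenizer.from_pretrained("path/to/keeper-checkpoint")
#
#     # The retriever and generator are attached separately (see the commented
#     # initialization in __init__); model names here are placeholders.
#     model.bert = AutoModel.from_pretrained("distilbert-base-multilingual-cased").to("cuda")
#     model.llm = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")
#
#     # Index a small document collection; the tokenized documents and prompt
#     # templates are stored as attributes on the model.
#     model.save_docs(["El tereré es una bebida tradicional paraguaya."], tokenizer)
#
#     # Retrieve the top-k documents and generate an answer conditioned on them.
#     query = tokenizer.encode("¿Qué es el tereré?")
#     answer_ids = model.generate(query, k=1, max_new_tokens=128)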