Update model_keeper.py
model_keeper.py CHANGED: +50 -32
@@ -1,4 +1,4 @@
-from configuration_keeper import KeeperConfig
+from .configuration_keeper import KeeperConfig
 
 import torch
 from transformers import (
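
The import fix matters for loading from the Hub: with trust_remote_code=True, transformers copies model_keeper.py and configuration_keeper.py into a dynamic module, where only the relative form of the import resolves. A minimal loading sketch, assuming a placeholder repo id:

    # Loading sketch; "user/keeper-model" is a hypothetical repo id.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "user/keeper-model",
        trust_remote_code=True,  # fetches model_keeper.py / configuration_keeper.py
    )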
@@ -61,45 +61,64 @@ class KeeperModelForCausalLM(PreTrainedModel):
         self.update_both = update_both
         print(f"Model n_cands: {self.n_cands}")
 
-        # Initialize empty buffers for document_vecs and document_mask
-        self.register_buffer('document_retriever_text', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_retriever_mask', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_retriever_type', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_model_text', torch.empty(0, dtype=torch.long))
-        # self.register_buffer('document_model_mask', torch.empty(0, dtype=torch.long))
-        # self.register_buffer('document_model_type', torch.empty(0, dtype=torch.long))
-        self.register_buffer('prompt_left', torch.empty(0, dtype=torch.long))
-        self.register_buffer('prompt_right', torch.empty(0, dtype=torch.long))
-        self.register_buffer('respuesta', torch.empty(0, dtype=torch.long))
+    def _load_from_state_dict(self, state_dict, *args, **kwargs):
+        super()._load_from_state_dict(state_dict, *args, **kwargs)
+        # Ensure CUDA is available
+        if torch.cuda.is_available():
+            device = torch.device('cuda')
+            if "document_retriever_text" in state_dict:
+                self.document_retriever_text = state_dict["document_retriever_text"].to(device)
+            if "document_retriever_mask" in state_dict:
+                self.document_retriever_mask = state_dict["document_retriever_mask"].to(device)
+            if "document_retriever_type" in state_dict:
+                self.document_retriever_type = state_dict["document_retriever_type"].to(device)
+            if "document_model_text" in state_dict:
+                self.document_model_text = state_dict["document_model_text"].to(device)
+            if "prompt_left" in state_dict:
+                self.prompt_left = state_dict["prompt_left"].to(device)
+            if "prompt_right" in state_dict:
+                self.prompt_right = state_dict["prompt_right"].to(device)
+            if "respuesta" in state_dict:
+                self.respuesta = state_dict["respuesta"].to(device)
+        else:
+            # Optionally handle the case where CUDA is not available
+            print("CUDA is not available. Tensors will remain on CPU.")
 
-    def
-        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
+    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, max_new_tokens=256, repetition_penalty=1.15, temperature=0.1, do_sample=True, **kwargs):
+
         query_model = {k: v.to("cuda") for k, v in query['tokens_model'].items()}
 
+        topk_texts = self.document_extractor(query, k)
+        concatenated_texts = torch.cat(topk_texts, dim=0)
+        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)
+        prompt_length = T.shape[1]
+        outputs = self.llm.generate(input_ids=T, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, do_sample=do_sample)
+        return outputs[0][prompt_length:].unsqueeze(0)
+
+    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):
+        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
+        query_vecs = self.forward_representation(query_retriever)
+        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
+        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")
+        self.score = self.forward_aggregation(query_vecs, query['tokens_retriever']["attention_mask"], document_vecs, self.document_retriever_mask)
+        k_val = min(k_val, self.score.numel())
+        topk_scores, topk_indices = torch.topk(self.score, k_val)
+        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]
 
     def forward_representation(self,
                                tokens,
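
The new retrieval path expects query to carry the same question tokenized twice: tokens_retriever for the scoring encoder and tokens_model for the LLM. document_extractor scores every stored document, keeps the top k (k_val = min(k_val, self.score.numel()) guards against corpora smaller than k), and generate splices the winners between the stored prompt halves, returning only the tokens produced after the prompt. A usage sketch; retriever_tok and model_tok are stand-ins for whichever tokenizers the repo pairs with each component, not names from this diff:

    # Hypothetical tokenizers; only the query dict layout comes from the code above.
    question = "¿Cuál es la capital de Paraguay?"
    query = {
        "tokens_retriever": retriever_tok(question, return_tensors="pt"),
        "tokens_model": model_tok(question, return_tensors="pt"),
    }

    answer_ids = model.generate(query, k=3, max_new_tokens=256)
    print(model_tok.decode(answer_ids[0], skip_special_tokens=True))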

@@ -145,16 +164,16 @@ class KeeperModelForCausalLM(PreTrainedModel):
     def prompt(self, left_p = None, right_p = None):
         if left_p is None:
             left_p = """ <bos><start_of_turn>user
-Eres un experto en cultura paraguaya que responde
+Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
+Segun el siguiente contexto:
 -------------------------------
 """
         if right_p is None:
             right_p = """
 -------------------------------
--
--
--
-
+- Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
+- Si encuentras la respuesta puedes copiarla.
+- Debes responder solamente en Espanol.
 Pregunta: """
         return left_p, right_p
 
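
Combined with the concatenation in generate, these templates yield a prompt shaped roughly like the following (an illustrative detokenized view; <...> marks the spliced-in tensors):

     <bos><start_of_turn>user
    Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
    Segun el siguiente contexto:
    -------------------------------
    <top-k retrieved documents>
    -------------------------------
    - Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
    - Si encuentras la respuesta puedes copiarla.
    - Debes responder solamente en Espanol.
    Pregunta: <question>
    Respuesta: <end_of_turn>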
@@ -187,5 +206,4 @@ Respuesta: <end_of_turn>
         # self.document_model_type = key_outputs['tokens_model']['token_type_ids']
         self.prompt_left = prompt_left_output['tokens_model']['input_ids']
         self.prompt_right = prompt_right_output['tokens_model']['input_ids']
-        self.respuesta = resp_model['input_ids']
-
+        self.respuesta = resp_model['input_ids']
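
These assignments live in the repo's document-ingestion method (its name falls outside the diff hunk), and the keys they populate are exactly the ones the new _load_from_state_dict hook looks up, so a checkpoint only restores retrieval state if its state dict carries them. A quick sanity-check sketch, assuming the common single-file transformers checkpoint name, which may differ for this repo:

    # Checkpoint-key check; "pytorch_model.bin" is the usual convention, not
    # a filename confirmed by this diff.
    import torch

    sd = torch.load("pytorch_model.bin", map_location="cpu")
    expected = ["document_retriever_text", "document_retriever_mask",
                "document_retriever_type", "document_model_text",
                "prompt_left", "prompt_right", "respuesta"]
    print("missing keys:", [k for k in expected if k not in sd])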
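
One design note on the retrieval step: document_extractor keeps the relevance scores on self.score rather than in a local, so the last retrieval stays inspectable after generate returns. A small sketch, assuming score has one row per query as the topk_indices[0] indexing implies:

    answer_ids = model.generate(query, k=3)
    scores = model.score             # relevance of each stored document to the query
    print(scores.topk(3).indices)    # the documents that were spliced into the prompt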
|