enpaiva committed
Commit 6a00905
1 Parent(s): f19652c

Update model_keeper.py

Files changed (1)
  1. model_keeper.py +50 -32
model_keeper.py CHANGED
@@ -1,4 +1,4 @@
-from configuration_keeper import KeeperConfig
+from .configuration_keeper import KeeperConfig
 
 import torch
 from transformers import (
@@ -61,45 +61,64 @@ class KeeperModelForCausalLM(PreTrainedModel):
         self.update_both = update_both
         print(f"Model n_cands: {self.n_cands}")
 
-        # Initialize empty buffers for document_vecs and document_mask
-        self.register_buffer('document_retriever_text', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_retriever_mask', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_retriever_type', torch.empty(0, dtype=torch.long))
-        self.register_buffer('document_model_text', torch.empty(0, dtype=torch.long))
-        # self.register_buffer('document_model_mask', torch.empty(0, dtype=torch.long))
-        # self.register_buffer('document_model_type', torch.empty(0, dtype=torch.long))
-        self.register_buffer('prompt_left', torch.empty(0, dtype=torch.long))
-        self.register_buffer('prompt_right', torch.empty(0, dtype=torch.long))
-        self.register_buffer('respuesta', torch.empty(0, dtype=torch.long))
-
-    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, **kwargs):
-
-        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
+    def _load_from_state_dict(self, state_dict, *args, **kwargs):
+        super()._load_from_state_dict(state_dict, *args, **kwargs)
+        # Ensure CUDA is available
+        if torch.cuda.is_available():
+            device = torch.device('cuda')
+            if "document_retriever_text" in state_dict:
+                self.document_retriever_text = state_dict["document_retriever_text"].to(device)
+            if "document_retriever_mask" in state_dict:
+                self.document_retriever_mask = state_dict["document_retriever_mask"].to(device)
+            if "document_retriever_type" in state_dict:
+                self.document_retriever_type = state_dict["document_retriever_type"].to(device)
+            if "document_model_text" in state_dict:
+                self.document_model_text = state_dict["document_model_text"].to(device)
+            if "prompt_left" in state_dict:
+                self.prompt_left = state_dict["prompt_left"].to(device)
+            if "prompt_right" in state_dict:
+                self.prompt_right = state_dict["prompt_right"].to(device)
+            if "respuesta" in state_dict:
+                self.respuesta = state_dict["respuesta"].to(device)
+        else:
+            # Optionally handle the case where CUDA is not available
+            print("CUDA is not available. Tensors will remain on CPU.")
+
+    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, max_new_tokens=256, repetition_penalty=1.15, temperature=0.1, do_sample=True, **kwargs):
+
         query_model = {k: v.to("cuda") for k, v in query['tokens_model'].items()}
 
-        query_vecs = self.forward_representation(query_retriever)
+        topk_texts = self.document_extractor(query, k)
 
-        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
+        concatenated_texts = torch.cat(topk_texts, dim=0)
 
-        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")
+        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)
 
-        self.score = self.forward_aggregation(query_vecs, query['tokens_model']["attention_mask"], document_vecs, self.document_retriever_mask)
+        prompt_length = T.shape[1]
 
-        k = min(k, self.score.numel())
+        outputs = self.llm.generate(input_ids=T, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, do_sample=do_sample)
 
-        topk_scores, topk_indices = torch.topk(self.score, k)
+        return outputs[0][prompt_length:].unsqueeze(0)
 
-        topk_texts = [self.document_model_text[i] for i in topk_indices[0].tolist()]
+    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):
 
-        concatenated_texts = torch.cat(topk_texts, dim=0)
+        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
 
-        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)
+        query_vecs = self.forward_representation(query_retriever)
 
-        prompt_length = T.shape[1]
+        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
 
-        outputs = self.llm.generate(input_ids=T, max_new_tokens=256, repetition_penalty=1.15)
+        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")
 
-        return outputs[0][prompt_length:].unsqueeze(0)
+        self.score = self.forward_aggregation(query_vecs, query['tokens_retriever']["attention_mask"], document_vecs, self.document_retriever_mask)
+
+        k_val = min(k_val, self.score.numel())
+
+        topk_scores, topk_indices = torch.topk(self.score, k_val)
+
+        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]
 
     def forward_representation(self,
                                tokens,
@@ -145,16 +164,16 @@ class KeeperModelForCausalLM(PreTrainedModel):
     def prompt(self, left_p = None, right_p = None):
         if left_p is None:
             left_p = """ <bos><start_of_turn>user
-Eres un experto en cultura paraguaya que responde segun el contexto:
+Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
+Segun el siguiente contexto:
 -------------------------------
 """
         if right_p is None:
             right_p = """
 -------------------------------
- - Debes responder solamente en Espanol
- - No utilices conocimientos previos.
- - Responde de forma clara, amable y concisa.
-
+ - Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
+ - Si encuentras la respuesta puedes copiarla.
+ - Debes responder solamente en Espanol.
 
 Pregunta: """
         return left_p, right_p
@@ -187,5 +206,4 @@ Respuesta: <end_of_turn>
         # self.document_model_type = key_outputs['tokens_model']['token_type_ids']
         self.prompt_left = prompt_left_output['tokens_model']['input_ids']
         self.prompt_right = prompt_right_output['tokens_model']['input_ids']
-        self.respuesta = resp_model['input_ids']
-
+        self.respuesta = resp_model['input_ids']
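A note on the _load_from_state_dict override above: because the document and prompt tensors were registered with register_buffer, they are serialized into the checkpoint alongside the parameters, so the override only needs to re-home them on the GPU after loading. It also relies on the state-dict keys being unprefixed, which holds for a top-level model; nested modules receive a prefix argument that qualifies every key. The following is a minimal, self-contained sketch of the same pattern, not code from this commit; BufferedModule and doc_ids are hypothetical names.

import torch
import torch.nn as nn

class BufferedModule(nn.Module):
    def __init__(self, n=4):
        super().__init__()
        # Buffers travel with the state dict like parameters but are not trained.
        self.register_buffer('doc_ids', torch.zeros(n, dtype=torch.long))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        # Re-home the restored buffer on the GPU when one is present;
        # state-dict keys are qualified by this module's prefix.
        key = prefix + 'doc_ids'
        if torch.cuda.is_available() and key in state_dict:
            self.doc_ids = state_dict[key].to('cuda')

m = BufferedModule()
m.doc_ids = torch.arange(4)   # populate the placeholder buffer
state = m.state_dict()        # 'doc_ids' is included automatically
m2 = BufferedModule()
m2.load_state_dict(state)     # triggers _load_from_state_dict
print(m2.doc_ids)             # tensor([0, 1, 2, 3]), on GPU if available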
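A usage sketch for the refactored generation path, assuming the model has already been loaded with its document and prompt tensors populated, and that two tokenizers are available for the retriever and the causal LM; build_query, retriever_tok, model_tok, and model are illustrative names, not part of this commit.

import torch

def build_query(question, retriever_tok, model_tok):
    # generate() expects the question tokenized twice: once for the
    # retriever encoder ('tokens_retriever') and once for the LM ('tokens_model').
    return {
        'tokens_retriever': dict(retriever_tok(question, return_tensors='pt')),
        'tokens_model': dict(model_tok(question, return_tensors='pt')),
    }

query = build_query("¿Qué es el tereré?", retriever_tok, model_tok)

# k sets how many stored documents document_extractor() splices into the
# prompt; the remaining keyword arguments are forwarded to self.llm.generate().
answer_ids = model.generate(query, k=3, max_new_tokens=256, temperature=0.1)
print(model_tok.decode(answer_ids[0], skip_special_tokens=True))

Exposing max_new_tokens, repetition_penalty, temperature, and do_sample as keyword arguments, instead of the previously hard-coded values, lets callers tune decoding without modifying the class.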