from .configuration_keeper import KeeperConfig

import torch
import numpy as np
from einops import rearrange
from typing import Dict
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)


class KeeperModelForCausalLM(PreTrainedModel):
    """
    ColBERT model from: https://arxiv.org/pdf/2004.12832.pdf
    We use a dot-product instead of cosine per term (slightly better).
    """
    config_class = KeeperConfig
    base_model_prefix = "keeper_model"

    def __init__(self, cfg, n_cands=8, update_both=False) -> None:
        super().__init__(cfg)

        self.bert = None
        self.llm = None

        # if cfg:
        #     print("Initializing KeeperModelForCausalLM from cfg")
        #     # Initialize from the configuration
        #     self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])
        #     bnb_config = BitsAndBytesConfig(
        #         load_in_4bit=True,
        #         bnb_4bit_quant_type="nf4",
        #         bnb_4bit_compute_dtype=torch.bfloat16
        #     )
        #     self.llm = AutoModelForCausalLM.from_pretrained(
        #         cfg.model_config['_name_or_path'],
        #         device_map=cfg.device_map,
        #         torch_dtype=torch.bfloat16,
        #         quantization_config=bnb_config
        #     )
        #     # Store kwargs for future serialization and loading
        #     # self.init_kwargs = {'cfg': cfg}
        #     print("Initialization complete")
        # else:
        #     # If cfg is not provided, this is handled in the from_pretrained method
        #     print("Initializing KeeperModelForCausalLM without cfg")

        self.n_cands = n_cands
        self.update_both = update_both
        print(f"Model n_cands: {self.n_cands}")

    def _load_from_state_dict(self, state_dict, *args, **kwargs):
        super()._load_from_state_dict(state_dict, *args, **kwargs)

        # Move the serialized document/prompt tensors (and sub-modules) to the GPU when available.
        if torch.cuda.is_available():
            device = torch.device('cuda')
            if "document_retriever_text" in state_dict:
                self.document_retriever_text = state_dict["document_retriever_text"].to(device)
            if "document_retriever_mask" in state_dict:
                self.document_retriever_mask = state_dict["document_retriever_mask"].to(device)
            if "document_retriever_type" in state_dict:
                self.document_retriever_type = state_dict["document_retriever_type"].to(device)
            if "document_model_text" in state_dict:
                self.document_model_text = state_dict["document_model_text"].to(device)
            if "prompt_left" in state_dict:
                self.prompt_left = state_dict["prompt_left"].to(device)
            if "prompt_right" in state_dict:
                self.prompt_right = state_dict["prompt_right"].to(device)
            if "respuesta" in state_dict:
                self.respuesta = state_dict["respuesta"].to(device)
            if "bert" in state_dict:
                self.bert = state_dict["bert"].to(device)
            if "llm" in state_dict:
                self.llm = state_dict["llm"].to(device)
        else:
            # Optionally handle the case where CUDA is not available
            print("CUDA is not available. Tensors will remain on CPU.")
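    # NOTE (illustrative, not part of the original code): `generate` and
    # `document_extractor` below expect `query` to be the nested dict produced by the
    # companion Keeper tokenizer, i.e. roughly:
    #
    #     query = {
    #         'tokens_retriever': {'input_ids': ..., 'attention_mask': ..., 'token_type_ids': ...},
    #         'tokens_model':     {'input_ids': ..., 'attention_mask': ...},
    #     }
    #
    # where 'tokens_retriever' holds tensors for the BERT retriever and 'tokens_model'
    # holds tensors for the generator LLM. The exact keys are an assumption inferred
    # from how they are accessed below.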
    def generate(self,
                 query: Dict[str, torch.LongTensor],
                 k: int = 3,
                 max_new_tokens=256,
                 repetition_penalty=1.15,
                 temperature=0.1,
                 do_sample=True,
                 **kwargs):
        # Tokens for the generator LLM
        query_model = {k: v.to("cuda") for k, v in query['tokens_model'].items()}

        # Retrieve the top-k documents (as generator token ids) and build the full prompt:
        # left prompt + retrieved context + right prompt + question + answer prefix.
        topk_texts = self.document_extractor(query, k)
        concatenated_texts = torch.cat(topk_texts, dim=0)
        T = torch.cat((self.prompt_left,
                       concatenated_texts.unsqueeze(0),
                       self.prompt_right,
                       query_model['input_ids'],
                       self.respuesta), dim=1)

        prompt_length = T.shape[1]
        outputs = self.llm.generate(input_ids=T,
                                    max_new_tokens=max_new_tokens,
                                    repetition_penalty=repetition_penalty,
                                    temperature=temperature,
                                    do_sample=do_sample)
        # Return only the newly generated tokens (strip the prompt).
        return outputs[0][prompt_length:].unsqueeze(0)

    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):
        # Tokens for the retriever encoder
        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
        query_vecs = self.forward_representation(query_retriever)

        doc_dic = {'input_ids': self.document_retriever_text,
                   'attention_mask': self.document_retriever_mask,
                   'token_type_ids': self.document_retriever_type}
        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")

        self.score = self.forward_aggregation(query_vecs,
                                              query['tokens_retriever']["attention_mask"],
                                              document_vecs,
                                              self.document_retriever_mask)

        k_val = min(k_val, self.score.numel())
        topk_scores, topk_indices = torch.topk(self.score, k_val)

        # Return the generator-side token ids of the top-k documents
        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]

    def forward_representation(self,
                               tokens,
                               max_seq_len=128,
                               sequence_type=None) -> torch.Tensor:
        if sequence_type == "doc":
            if self.update_both:
                # update_both: gradients also flow through the document encoder
                vecs = self.bert(**tokens)[0]
            else:
                with torch.no_grad():
                    vecs = self.bert(**tokens)[0]  # assuming a distilbert model here
        else:
            with torch.no_grad():
                vecs = self.bert(**tokens)[0]
        # vecs = self.compressor(vecs)
        return vecs

    def forward_aggregation(self, query_vecs, query_mask, document_vecs, document_mask):
        # query_vecs: B x N x D
        # doc_vecs: (B * k) x N x D

        # Duplicate each query so it can be scored against its k candidate documents
        _bsz = query_vecs.shape[0]
        n_cands = document_vecs.shape[0] // _bsz
        query_vecs_dup = query_vecs.repeat_interleave(n_cands, dim=0).contiguous()

        # Term-by-term similarities, with padded document positions masked out
        score = torch.bmm(query_vecs_dup, document_vecs.transpose(1, 2))
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1, score.shape[1], -1)
        score[~exp_mask] = -10000

        # max pooling over document dimension
        score = score.max(-1).values
        query_mask_dup = query_mask.repeat_interleave(n_cands, dim=0).contiguous()
        score[~(query_mask_dup.bool())] = 0
        score = rearrange(score.sum(-1), '(b n) -> b n', n=n_cands)  # B x k
        return score

    def prompt(self, left_p=None, right_p=None):
        if left_p is None:
            left_p = """
user
Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
Segun el siguiente contexto:
-------------------------------
"""
        if right_p is None:
            right_p = """
-------------------------------
- Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
- Si encuentras la respuesta puedes copiarla.
- Debes responder solamente en Espanol.
Pregunta: """
        return left_p, right_p
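    # Illustrative shape walk-through for `forward_aggregation` above (ColBERT-style MaxSim),
    # assuming a single query (B = 1), n_cands = 2 documents, N = 4 tokens per sequence
    # and D = 8 dimensions per term vector:
    #
    #     query_vecs:    (1, 4, 8)  -> repeat_interleave -> query_vecs_dup: (2, 4, 8)
    #     document_vecs: (2, 4, 8)  -> transpose(1, 2)   -> (2, 8, 4)
    #     torch.bmm(...)            -> score: (2, 4, 4)     (query-term x doc-term similarities)
    #     score.max(-1).values      -> (2, 4)               (best doc term per query term)
    #     score.sum(-1) + rearrange -> (1, 2)               (one relevance score per candidate)
    #
    # Padded document positions are pushed to -10000 before the max, and padded query
    # positions are zeroed before the sum, so padding never contributes to the score.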
    def save_docs(self, docs: list, tokenizer, max_seq_len=128):
        # Tokenize the prompt
        prompt_left, prompt_right = self.prompt()
        prompt_left_output = tokenizer.encode(prompt_left)
        prompt_right_output = tokenizer.encode(prompt_right)

        # Tokenize the documents
        doc_outputs = tokenizer.encode(docs, max_length=max_seq_len, padding='max_length', truncation=True)

        # Move the tensors to CUDA
        # (TODO optimize: tensors that will never be used on the GPU are also stored there)
        doc_outputs = {k: v.to("cuda") for k, v in doc_outputs.items()}
        prompt_left_output = {k: v.to("cuda") for k, v in prompt_left_output.items()}
        prompt_right_output = {k: v.to("cuda") for k, v in prompt_right_output.items()}

        # Tokenize the answer prefix
        resp = tokenizer.encode("""
Respuesta:
model
""")
        resp_model = {k: v.to("cuda") for k, v in resp['tokens_model'].items()}

        # Update the buffers with the document tensors
        self.document_retriever_text = doc_outputs['tokens_retriever']['input_ids']
        self.document_retriever_mask = doc_outputs['tokens_retriever']['attention_mask']
        self.document_retriever_type = doc_outputs['tokens_retriever']['token_type_ids']
        self.document_model_text = doc_outputs['tokens_model']['input_ids']
        # self.document_model_mask = key_outputs['tokens_model']['attention_mask']
        # self.document_model_type = key_outputs['tokens_model']['token_type_ids']

        self.prompt_left = prompt_left_output['tokens_model']['input_ids']
        self.prompt_right = prompt_right_output['tokens_model']['input_ids']
        self.respuesta = resp_model['input_ids']
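
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes a trained checkpoint at a hypothetical path "keeper-checkpoint"
# and a companion Keeper tokenizer whose `encode` returns the nested
# {'tokens_retriever': ..., 'tokens_model': ...} dict consumed by this class;
# the import path, class name, and `decode` call below are assumptions.
#
#     from .tokenization_keeper import KeeperTokenizer   # assumed companion class
#
#     tokenizer = KeeperTokenizer.from_pretrained("keeper-checkpoint")
#     model = KeeperModelForCausalLM.from_pretrained("keeper-checkpoint")
#
#     docs = [
#         "El terere es una bebida tradicional paraguaya a base de yerba mate.",
#         "El nanduti es un encaje artesanal originario de Itaugua.",
#     ]
#     model.save_docs(docs, tokenizer)          # cache document/prompt tensors on the GPU
#
#     query = tokenizer.encode("Que es el terere?")
#     answer_ids = model.generate(query, k=1)   # retrieve top-k docs, then generate
#     print(tokenizer.decode(answer_ids[0]))    # assumes the tokenizer exposes decode
# ---------------------------------------------------------------------------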