import logging
import os
from typing import List, Tuple, Union

import requests
import torch
from huggingface_hub import snapshot_download
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class RAG:
    """Retrieval-augmented generation helper.

    Loads a FAISS vector store downloaded from the Hugging Face Hub,
    retrieves the passages closest to a user prompt and forwards them,
    together with the prompt, to a remote text-generation endpoint.
    """

    # Returned by ``get_response`` when the model produces an empty answer.
    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."

    def __init__(self, hf_token, embeddings_model, repo_name, model_name):
        """Download the vector store and load it with the given embeddings.

        Args:
            hf_token: Hugging Face token; stored on the instance.
            embeddings_model: Name of the embeddings model passed to
                ``HuggingFaceEmbeddings``.
            repo_name: Hub repository holding the serialized FAISS index.
            model_name: Stored on the instance; ``predict_dolly`` posts to
                it, so it is used as the endpoint URL there.
        """
        # Download the vectorstore from Hugging Face Hub
        vectorstore = snapshot_download(repo_name)

        self.model_name = model_name
        self.hf_token = hf_token
        # self.rerank_model = rerank_model
        # self.rerank_number_contexts = rerank_number_contexts

        # Load the vector store (embeddings are computed on CPU).
        embeddings = HuggingFaceEmbeddings(
            model_name=embeddings_model,
            model_kwargs={'device': 'cpu'},
        )
        # SECURITY NOTE: allow_dangerous_deserialization=True lets FAISS
        # unpickle the stored index metadata. Only point ``repo_name`` at
        # repositories you trust.
        self.vectore_store = FAISS.load_local(
            vectorstore,
            embeddings,
            allow_dangerous_deserialization=True,
        )

        logging.info("RAG loaded!")

    # NOTE: reranking is currently disabled; the implementation is kept
    # below, commented out, for reference.
    # def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
    #     """
    #     Rerank the contexts based on their relevance to the given instruction.
    #     """
    #     rerank_model = self.rerank_model
    #     tokenizer = AutoTokenizer.from_pretrained(rerank_model)
    #     model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
    #
    #     def get_score(query, passage):
    #         """Calculate the relevance score of a passage with respect to a query."""
    #         inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
    #         with torch.no_grad():
    #             outputs = model(**inputs)
    #             logits = outputs.logits
    #         score = logits.view(-1, ).float()
    #         return score
    #
    #     scores = [get_score(instruction, c[0].page_content) for c in contexts]
    #     combined = list(zip(contexts, scores))
    #     sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
    #     sorted_texts, _ = zip(*sorted_combined)
    #     return sorted_texts[:number_of_contexts]

    def get_context(self, instruction, number_of_contexts=2):
        """Retrieve the most relevant contexts for a given instruction.

        Args:
            instruction: Query text used for the similarity search.
            number_of_contexts: How many passages to retrieve. Coerced with
                ``int()`` so that float values are accepted.

        Returns:
            Whatever ``similarity_search_with_score`` returns; the rest of
            this class reads each item as ``item[0]`` being the document.
        """
        documents = self.vectore_store.similarity_search_with_score(
            instruction,
            k=int(number_of_contexts),
        )
        # documents = self.rerank_contexts(instruction, documents, number_of_contexts=number_of_contexts)
        logging.info("Retrieved %d documents", len(documents))
        return documents

    def predict_dolly(self, instruction, context, model_parameters):
        """Query the endpoint at ``self.model_name`` with a Dolly-style prompt.

        Args:
            instruction: The user question.
            context: Retrieved text placed in the ``### Context`` section.
            model_parameters: Sent verbatim as the ``parameters`` field of
                the JSON payload.

        Returns:
            The text after the last ``###`` marker of the first generated
            result, minus its first 8 characters.
        """
        api_key = os.getenv("HF_TOKEN")

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
        # prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"

        payload = {
            "inputs": query,
            "parameters": model_parameters,
        }

        response = requests.post(self.model_name, headers=headers, json=payload)

        # The [8:] slice drops 8 characters, which matches the length of the
        # " Answer\n" header -- presumably the endpoint echoes the prompt
        # back; TODO confirm against the deployed model.
        return response.json()[0]["generated_text"].split("###")[-1][8:]

    def predict_completion(self, instruction, context, model_parameters):
        """Ask the chat-completions endpoint to answer using the context.

        The endpoint URL is read from the ``MODEL`` environment variable and
        the API key from ``HF_TOKEN``.

        Args:
            instruction: The user question.
            context: Retrieved text the answer should be based on.
            model_parameters: Must contain ``temperature``,
                ``max_new_tokens`` and ``repetition_penalty``.

        Returns:
            The content of the first choice's message.
        """
        client = OpenAI(
            base_url=os.getenv("MODEL"),
            api_key=os.getenv("HF_TOKEN"),
        )

        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"

        chat_completion = client.chat.completions.create(
            model="tgi",
            messages=[
                # Send the context-augmented query: sending only the bare
                # instruction would discard the retrieved passages.
                {"role": "user", "content": query},
            ],
            temperature=model_parameters["temperature"],
            max_tokens=model_parameters["max_new_tokens"],
            stream=False,
            stop=["<|im_end|>"],
            extra_body={
                # NOTE(review): repetition_penalty is shifted by -2 before
                # being sent as presence_penalty; the rationale is not
                # visible in this file -- confirm the intended mapping.
                "presence_penalty": model_parameters["repetition_penalty"] - 2,
                "do_sample": False,
            },
        )

        return chat_completion.choices[0].message.content

    def beautiful_context(self, docs):
        """Format retrieved documents for prompting and for display.

        Args:
            docs: Iterable of items whose first element exposes
                ``page_content`` and a ``metadata`` mapping with a ``"url"``
                key.

        Returns:
            A tuple ``(text_context, full_context, source_context)``:
            the passages concatenated without separator, the passages each
            followed by their URL, and the list of URLs.
        """
        text_parts = []
        full_parts = []
        source_context = []

        for doc in docs:
            document = doc[0]
            url = document.metadata["url"]
            text_parts.append(document.page_content)
            full_parts.append(document.page_content + "\n" + url + "\n\n")
            source_context.append(url)

        return "".join(text_parts), "".join(full_parts), source_context

    def get_response(
        self, prompt: str, model_parameters: dict
    ) -> Union[str, Tuple[str, str, List[str]], None]:
        """Retrieve context for ``prompt`` and generate an answer.

        Args:
            prompt: The user question.
            model_parameters: Generation parameters plus a ``"NUM_CHUNKS"``
                entry giving the number of passages to retrieve. The
                dictionary is not modified.

        Returns:
            ``(response, full_context, source)`` on success,
            ``NO_ANSWER_MESSAGE`` (a plain string) when the model returns an
            empty answer, or ``None`` when any step raises; the error is
            logged with its traceback.
        """
        try:
            # Work on a copy so the caller's dict keeps "NUM_CHUNKS" and can
            # be reused for the next request.
            generation_parameters = dict(model_parameters)
            num_chunks = generation_parameters.pop("NUM_CHUNKS")

            docs = self.get_context(prompt, num_chunks)
            text_context, full_context, source = self.beautiful_context(docs)

            response = self.predict_completion(prompt, text_context, generation_parameters)

            if not response:
                return self.NO_ANSWER_MESSAGE

            return response, full_context, source
        except Exception:
            # Top-level boundary: keep the best-effort behaviour (return
            # None) but record the full traceback instead of a bare print.
            logging.exception("Failed to generate a response")
            return None