import os

import weaviate
from sentence_transformers import CrossEncoder, SentenceTransformer

from src.llama_cpp_chat_engine import LlamaCPPChatEngine


def _require_env(name):
    """Return the value of environment variable *name*, raising if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} must be set")
    return value


class ChatRagAgent:
    """Retrieval-augmented chat agent.

    Embeds the user's message, fetches nearest-neighbour objects from the
    Weaviate collection named "Collection", reranks them with a cross-encoder,
    and passes the sufficiently relevant question/answer pairs as context to a
    llama.cpp chat engine.

    The agent holds an open Weaviate connection; call ``close()`` (or use the
    agent as a context manager) when done with it.
    """

    # Number of nearest-neighbour candidates fetched from Weaviate per query.
    RETRIEVAL_LIMIT = 10
    # Number of top reranked candidates considered for the prompt context.
    RERANK_TOP_K = 2
    # Minimum (softmax) reranker score for a candidate to be used as context.
    SCORE_THRESHOLD = 0.2

    def __init__(self):
        self._chat_engine = LlamaCPPChatEngine("Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf")
        self.n_ctx = self._chat_engine.n_ctx
        self._vectorizer = SentenceTransformer(
            "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
        )
        self._reranker = CrossEncoder(
            "jinaai/jina-reranker-v1-turbo-en",
            trust_remote_code=True,
        )
        # Keep a reference to the client so the connection can be closed;
        # previously it was discarded right after fetching the collection.
        self._client = weaviate.connect_to_wcs(
            cluster_url=_require_env("WCS_URL"),
            auth_credentials=weaviate.auth.AuthApiKey(_require_env("WCS_KEY")),
        )
        self._collection = self._client.collections.get("Collection")

    def close(self):
        """Close the underlying Weaviate client connection."""
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def chat(self, messages, user_message):
        """Answer *user_message*, using retrieved Q&A pairs as extra context.

        Args:
            messages: Prior conversation, forwarded unchanged to the chat engine.
            user_message: The new user message; it is embedded for retrieval,
                used as the reranking query, and forwarded to the chat engine.

        Returns:
            A ``(response, sources)`` tuple: whatever the chat engine's
            ``chat`` returns, and the ``link`` property of every retrieved
            object that was included as context (same order as the context).
        """
        embedding = self._vectorizer.encode(user_message).tolist()
        docs = self._collection.query.near_vector(
            near_vector=embedding, limit=self.RETRIEVAL_LIMIT
        )
        # Filter once so context and sources are guaranteed to correspond.
        hits = self._relevant_objects(user_message, docs.objects)
        context = [
            f"Question: {obj.properties['question']}\n"
            f"Answer: {obj.properties['answer']}\n"
            for obj in hits
        ]
        sources = [obj.properties["link"] for obj in hits]
        return self._chat_engine.chat(messages, user_message, context), sources

    def _relevant_objects(self, query, objects):
        """Rerank *objects* by their ``answer`` against *query*; keep those above the threshold."""
        # Nothing retrieved: skip the reranker, which may fail on an empty
        # document list.
        if not objects:
            return []
        ranks = self._reranker.rank(
            query,
            [obj.properties["answer"] for obj in objects],
            top_k=self.RERANK_TOP_K,
            apply_softmax=True,
        )
        return [
            objects[rank["corpus_id"]]
            for rank in ranks
            if rank["score"] > self.SCORE_THRESHOLD
        ]