Spaces:
Paused
Paused
File size: 6,426 Bytes
8765003 180c827 8765003 180c827 8765003 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import pickle
import textwrap
import logging
from typing import List
import faiss
import numpy as np
from llama_cpp import Llama
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers.util import cos_sim
# === Logger configuration ===
logger = logging.getLogger("RAGEngine")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
MAX_TOKENS = 512
class RAGEngine:
def __init__(self, model_path: str, vector_path: str, index_path: str, model_threads: int = 4):
logger.info("📦 Initialisation du moteur RAG...")
self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=model_threads)
self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Warmup pour éviter le temps de latence initial
try:
self.llm("Bonjour", max_tokens=1)
except Exception as e:
logger.warning(f"Warmup LLM échoué : {e}")
logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
with open(vector_path, "rb") as f:
chunk_texts = pickle.load(f)
nodes = [TextNode(text=chunk) for chunk in chunk_texts]
faiss_index = faiss.read_index(index_path)
vector_store = FaissVectorStore(faiss_index=faiss_index)
self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
logger.info("✅ Moteur RAG initialisé avec succès.")
def reformulate_with_context(self, question: str, context_sample: str) -> str:
logger.info("🔁 Reformulation de la question avec contexte...")
prompt = f"""Tu es un assistant expert en machine learning. Ton rôle est de reformuler les questions utilisateur en tenant compte du contexte ci-dessous, extrait d’un rapport technique sur un projet de reconnaissance de maladies de plantes.
Ta mission est de transformer une question vague ou floue en une question précise et adaptée au contenu du rapport. Ne donne pas une interprétation hors sujet. Ne reformule pas en termes de produits commerciaux.
Contexte :
{context_sample}
Question initiale : {question}
Question reformulée :"""
output = self.llm(prompt, max_tokens=128, stop=[""], stream=False)
reformulated = output["choices"][0]["text"].strip()
logger.info(f"📝 Reformulée avec contexte : {reformulated}")
return reformulated
def get_adaptive_top_k(self, question: str) -> int:
q = question.lower()
if len(q.split()) <= 7:
return 8
elif any(w in q for w in ["liste", "résume", "quels sont", "explique", "comment"]):
return 10
return 8
def rerank_nodes(self, question: str, retrieved_nodes, top_k: int = 3):
logger.info(f"🔍 Re-ranking des {len(retrieved_nodes)} chunks pour la question : « {question} »")
q_emb = self.embed_model.get_query_embedding(question)
if q_emb is None:
logger.warning("Embedding de la wuestion introuvable")
return retrieved_nodes[:top_k]
scored_nodes = []
for node in retrieved_nodes:
chunk_text = node.get_content()
chunk_emb = self.embed_model.get_text_embedding(chunk_text)
score = float(np.dot(q_emb, chunk_emb))
scored_nodes.append((score, node))
ranked_nodes = sorted(scored_nodes, key=lambda x: x[0], reverse=True)
logger.info("📊 Chunks les plus pertinents :")
for i, (score, node) in enumerate(ranked_nodes[:top_k]):
chunk_preview = textwrap.shorten(node.get_content().replace("\n", " "), width=100)
logger.info(f"#{i+1} | Score: {score:.4f} | {chunk_preview}")
return [n for _, n in ranked_nodes[:top_k]]
def retrieve_context(self, question: str, top_k: int = 3):
logger.info(f"📥 Récupération du contexte...")
retriever = self.index.as_retriever(similarity_top_k=top_k)
retrieved_nodes = retriever.retrieve(question)
reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
return context, reranked_nodes
def ask(self, question_raw: str) -> str:
logger.info(f"💬 Question reçue : {question_raw}")
context_sample, _ = self.retrieve_context(question_raw, top_k=3)
reformulated = self.reformulate_with_context(question_raw, context_sample)
logger.info(f"📝 Question reformulée : {reformulated}")
top_k = self.get_adaptive_top_k(reformulated)
context, _ = self.retrieve_context(reformulated, top_k)
prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
Contexte :
{context}
Question : {reformulated}
### Réponse:"""
output = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=False)
response = output["choices"][0]["text"].strip().split("###")[0]
logger.info(f"🧠 Réponse générée : {response[:120]}{'...' if len(response) > 120 else ''}")
return response
def ask_stream(self, question: str):
logger.info(f"💬 [Stream] Question reçue : {question}")
top_k = self.get_adaptive_top_k(question)
context, _ = self.retrieve_context(question, top_k)
prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
Contexte :
{context}
Question : {question}
### Réponse:"""
logger.info("📡 Début du streaming de la réponse...")
stream = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=True)
for chunk in stream:
print(chunk["choices"][0]["text"], end="", flush=True)
|