import os

from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser

from prompt_engineering import build_prompt
# ── Global state ──────────────────────────────────────────────────────────────
rag_chain = None
retriever = None

# NOTE: On Hugging Face Spaces, set MISTRAL_API_KEY in your Space's Settings > Secrets.
# Do NOT use python-dotenv or .env files on HF Spaces.
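# Optional fail-fast sketch (not in the original): check the secret at import
# time so a missing key surfaces as a clear startup error rather than a
# KeyError deep inside lifespan().
if not os.environ.get("MISTRAL_API_KEY"):
    raise RuntimeError(
        "MISTRAL_API_KEY is not set. On Hugging Face Spaces, add it under "
        "Settings > Secrets."
    )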
# ── Helper: format retrieved docs ─────────────────────────────────────────────
def format_docs(docs) -> str:
    """Convert the retrieved documents into prompt-ready text."""
    return "\n\n".join(
        f"[Source: {os.path.basename(doc.metadata.get('source', 'Inconnue'))}]\n{doc.page_content}"
        for doc in docs
    )
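# Illustrative output for two hypothetical chunks (file names and content invented):
#
#   [Source: faq_livraison.pdf]
#   Les commandes sont expédiées sous 48 h ouvrées...
#
#   [Source: conditions_retour.pdf]
#   Les retours sont acceptés sous 30 jours...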
def extract_sources(docs) -> list[str]:
    """Build deduplicated source labels from the documents' metadata."""
    sources = []
    seen = set()
    for doc in docs:
        source = doc.metadata.get("source", "Inconnue")
        page = doc.metadata.get("page")
        label = (
            f"{os.path.basename(source)}, page {page + 1}"  # loader pages are 0-indexed
            if page is not None
            else os.path.basename(source)
        )
        if label not in seen:
            sources.append(label)
            seen.add(label)
    return sources
def get_confidence(docs_with_scores: list) -> str:
    """Derive a confidence level from the FAISS scores (L2 distance: lower is closer)."""
    if not docs_with_scores:
        return "low"
    avg_score = sum(s for _, s in docs_with_scores) / len(docs_with_scores)
    if avg_score < 0.4:
        return "high"
    elif avg_score < 0.8:
        return "medium"
    return "low"
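# Worked examples (hypothetical scores; with L2 distance, lower means closer):
#   get_confidence([(doc_a, 0.25), (doc_b, 0.35)])  -> "high"    (avg 0.30)
#   get_confidence([(doc_a, 0.50), (doc_b, 0.90)])  -> "medium"  (avg 0.70)
#   get_confidence([])                              -> "low"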
# ── Startup loading ───────────────────────────────────────────────────────────
@asynccontextmanager  # FastAPI's lifespan parameter expects an async context manager
async def lifespan(app: FastAPI):
    global rag_chain, retriever

    embedding = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True}
    )

    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=embedding,
        allow_dangerous_deserialization=True
    )

    llm = ChatOpenAI(
        base_url="https://api.mistral.ai/v1",
        api_key=os.environ["MISTRAL_API_KEY"],  # set in HF Spaces Secrets
        model_name="mistral-medium"
    )

    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    rag_chain = (
        {
            "context": retriever | RunnableLambda(format_docs),
            "question": RunnablePassthrough()
        }
        | build_prompt()
        | llm
        | StrOutputParser()
    )
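    # Data flow (LCEL): both branches of the dict receive the same input question.
    # "context" pipes it through the retriever and then format_docs, while
    # "question" passes it through unchanged; the resulting mapping then feeds
    # build_prompt(), the LLM, and finally the string output parser.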
    print("✅ RAG chain loaded and ready.")
    yield
    print("🛑 Shutting down the API.")
# ── FastAPI application ───────────────────────────────────────────────────────
app = FastAPI(
    title="ShopVite RAG API",
    version="1.0.0",
    lifespan=lifespan
)
# ── Schemas ───────────────────────────────────────────────────────────────────
class AskRequest(BaseModel):
    question: str

class AskResponse(BaseModel):
    answer: str
    sources: list[str]
    confidence: str  # "high" | "medium" | "low" | "out_of_context"
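# Example /ask response body (illustrative values):
# {
#     "answer": "Les retours sont acceptés sous 30 jours...",
#     "sources": ["conditions_retour.pdf, page 2"],
#     "confidence": "high"
# }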
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/health")  # route decorator restored; the exact path name is assumed
def health():
    if rag_chain is None:
        raise HTTPException(status_code=503, detail="RAG chain non initialisée.")
    return {
        "status": "ok",
        "model": "mistral-medium",
        "vectorstore": "faiss_index"
    }
@app.post("/ask", response_model=AskResponse)  # decorator restored; path assumed from the handler name
def ask(body: AskRequest):
    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="La question ne peut pas être vide.")
    if len(question) > 500:
        raise HTTPException(status_code=400, detail="Question trop longue (max 500 caractères).")

    # Retrieve the documents together with their FAISS scores
    docs_with_scores = retriever.vectorstore.similarity_search_with_score(question, k=3)
    docs = [doc for doc, _ in docs_with_scores]
    sources = extract_sources(docs)

    # Generate the answer through the RAG chain
    try:
        answer = rag_chain.invoke(question)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur LLM : {str(e)}") from e
    # Detect out-of-context questions (sentinel emitted by the prompt)
    if "HORS_CONTEXTE" in answer:
        return AskResponse(
            answer=(
                "Je suis désolé, cette information ne figure pas dans mes documents. "
                "Pour toute question spécifique, contactez notre support : "
                "support@shopvite.fr | 01 23 45 67 89 (lun-ven, 9h-18h)."
            ),
            sources=[],
            confidence="out_of_context"
        )

    confidence = get_confidence(docs_with_scores)
    return AskResponse(
        answer=answer,
        sources=sources,
        confidence=confidence
    )
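# ── Local usage (sketch) ──────────────────────────────────────────────────────
# Assumes this file is named app.py; Hugging Face Spaces conventionally serves
# on port 7860. The /ask path matches the route restored above.
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Quels sont les délais de livraison ?"}'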