import concurrent.futures
import json
import os
import pickle
import re
import time
from functools import lru_cache
from pathlib import Path
from typing import Generator, Tuple

import faiss
import gradio as gr
import numpy as np
import torch
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain_mistralai.chat_models import ChatMistralAI
from sentence_transformers import SentenceTransformer

class OptimizedRAGLoader:

    def __init__(self,
                 docs_folder: str = "./docs",
                 splits_folder: str = "./splits",
                 index_folder: str = "./index"):

        self.docs_folder = Path(docs_folder)
        self.splits_folder = Path(splits_folder)
        self.index_folder = Path(index_folder)

        # Create the output folders if they do not exist yet.
        for folder in [self.splits_folder, self.index_folder]:
            folder.mkdir(parents=True, exist_ok=True)

        # Paths for the persisted splits, FAISS index, and document store.
        self.splits_path = self.splits_folder / "splits.json"
        self.index_path = self.index_folder / "faiss.index"
        self.documents_path = self.index_folder / "documents.pkl"

        # Populated lazily by load_index() / create_index().
        self.index = None
        self.indexed_documents = None

        # Embedding model, moved to the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = SentenceTransformer("intfloat/multilingual-e5-large")
        self.encoder.to(self.device)

        # Thread pool for parallel document processing.
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

        # In-memory cache for retrieval results.
        self.response_cache = {}

    # Note: lru_cache on an instance method also caches `self`; that is
    # acceptable here because a single loader instance lives for the whole
    # process.
    @lru_cache(maxsize=1000)
    def encode(self, text: str):
        """Cached encoding function."""
        with torch.no_grad():
            embeddings = self.encoder.encode(
                text,
                convert_to_numpy=True,
                normalize_embeddings=True
            )
        return embeddings

    def batch_encode(self, texts: list):
        """Batch encoding for multiple texts."""
        with torch.no_grad():
            embeddings = self.encoder.encode(
                texts,
                batch_size=32,
                convert_to_numpy=True,
                normalize_embeddings=True,
                show_progress_bar=False
            )
        return embeddings

    def load_and_split_texts(self):
        # Reuse previously saved splits when available.
        if self._splits_exist():
            return self._load_existing_splits()

        documents = []
        futures = []

        # Split each source file in parallel.
        for file_path in self.docs_folder.glob("*.txt"):
            future = self.executor.submit(self._process_file, file_path)
            futures.append(future)

        for future in concurrent.futures.as_completed(futures):
            documents.extend(future.result())

        self._save_splits(documents)
        return documents

    def _process_file(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        # Split on sentence boundaries (., !, ?) followed by whitespace.
        chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

        return [
            Document(
                page_content=chunk,
                metadata={
                    'source': file_path.name,
                    'chunk_id': i,
                    'total_chunks': len(chunks)
                }
            )
            for i, chunk in enumerate(chunks)
        ]
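
    # The three helpers below are referenced by load_and_split_texts() but
    # missing from the original listing. This is a minimal reconstruction,
    # assuming the splits are persisted as JSON at self.splits_path with the
    # same fields produced by _process_file.
    def _splits_exist(self) -> bool:
        """Check whether previously saved splits exist on disk."""
        return self.splits_path.exists()

    def _load_existing_splits(self):
        """Rebuild Document objects from the saved JSON splits."""
        with open(self.splits_path, 'r', encoding='utf-8') as f:
            records = json.load(f)
        return [
            Document(page_content=r['page_content'], metadata=r['metadata'])
            for r in records
        ]

    def _save_splits(self, documents):
        """Persist the split documents as JSON for reuse on the next run."""
        records = [
            {'page_content': doc.page_content, 'metadata': doc.metadata}
            for doc in documents
        ]
        with open(self.splits_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False, indent=2)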

    def load_index(self) -> bool:
        """
        Load the FAISS index and its associated documents if they exist.

        Returns:
            bool: True if the index was loaded, False otherwise.
        """
        if not self._index_exists():
            print("No index found.")
            return False

        print("Loading existing index...")
        try:
            self.index = faiss.read_index(str(self.index_path))

            with open(self.documents_path, 'rb') as f:
                self.indexed_documents = pickle.load(f)

            print(f"Index loaded with {self.index.ntotal} vectors")
            return True

        except Exception as e:
            print(f"Error while loading the index: {e}")
            return False

    def create_index(self, documents=None):
        if documents is None:
            documents = self.load_and_split_texts()

        if not documents:
            return False

        texts = [doc.page_content for doc in documents]
        embeddings = self.batch_encode(texts)

        # With normalized embeddings, L2 distance preserves cosine ranking.
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)

        if torch.cuda.is_available():
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

        self.index.add(np.array(embeddings).astype('float32'))
        self.indexed_documents = documents

        # FAISS can only serialize CPU indexes, so move it back if needed.
        cpu_index = faiss.index_gpu_to_cpu(self.index) if torch.cuda.is_available() else self.index
        faiss.write_index(cpu_index, str(self.index_path))

        with open(self.documents_path, 'wb') as f:
            pickle.dump(documents, f)

        return True

    def _index_exists(self) -> bool:
        """Check whether the index and its associated documents exist."""
        return self.index_path.exists() and self.documents_path.exists()

    def get_retriever(self, k: int = 5):
        # Make sure an index is available, building one if necessary.
        if self.index is None:
            if not self.load_index():
                if not self.create_index():
                    raise ValueError("Unable to load or create index")

        def retriever_function(query: str) -> list:
            # Serve repeated queries from the cache.
            cache_key = f"{query}_{k}"
            if cache_key in self.response_cache:
                return self.response_cache[cache_key]

            query_embedding = self.encode(query)

            distances, indices = self.index.search(
                np.array([query_embedding]).astype('float32'),
                k
            )

            # FAISS returns -1 for missing neighbors; skip them.
            results = [
                self.indexed_documents[idx]
                for idx in indices[0]
                if idx != -1
            ]

            self.response_cache[cache_key] = results
            return results

        return retriever_function


mistral_api_key = os.getenv("mistral_api_key")  # expects this environment variable to be set
llm = ChatMistralAI(
    model="mistral-large-latest",
    mistral_api_key=mistral_api_key,
    temperature=0.1,
    streaming=True,
)

rag_loader = OptimizedRAGLoader()
retriever = rag_loader.get_retriever(k=10)

# Cache of completed answers keyed by question text.
question_cache = {}

# System prompt (in Arabic): answer in Arabic using the provided context,
# say when the context is insufficient, be concise and precise, and cite the
# source article number and the law name/number derived from the file name.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة.
استخدم المعلومات التالية للإجابة على السؤال:

{context}

إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك.
أجب بشكل موجز ودقيق.
أذكر رقم المادة المصدر.
أذكر اسم ورقم القانون انطلاقا من اسم الملف.
"""),
    ("human", "{question}")
])


def process_question(question: str) -> Generator[Tuple[str, str], None, None]:
    """
    Process the question and yield the answer progressively.
    """
    # Serve fully cached answers immediately.
    if question in question_cache:
        yield question_cache[question]
        return

    relevant_docs = retriever(question)
    context = "\n".join([doc.page_content for doc in relevant_docs])

    prompt = prompt_template.format_messages(
        context=context,
        question=question
    )

    current_response = ""
    for chunk in llm.stream(prompt):
        if isinstance(chunk, str):
            current_response += chunk
        else:
            current_response += chunk.content
        yield current_response, context

    # Cache the completed answer together with its context.
    question_cache[question] = (current_response, context)


custom_css = """
/* Import Google Fonts - Noto Sans Arabic */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Arabic:wght@300;400;500;600;700&display=swap');

/* General styles */
:root {
    --primary-color: #4299e1;
    --secondary-color: #666666;
    --accent-color: #4299E1;
    --background-color: #ffffff;
    --border-radius: 8px;
    --font-family-arabic: 'Noto Sans Arabic', Arial, sans-serif;
}

/* Base style */
body {
    font-family: var(--font-family-arabic);
    background-color: var(--background-color);
    color: var(--primary-color);
}

/* RTL text styles */
.rtl-text {
    text-align: right !important;
    direction: rtl !important;
    font-family: var(--font-family-arabic) !important;
}

.rtl-text textarea {
    text-align: right !important;
    direction: rtl !important;
    padding: 1rem !important;
    border-radius: var(--border-radius) !important;
    border: 1px solid #E2E8F0 !important;
    background-color: #ffffff !important;
    color: var(--primary-color) !important;
    font-size: 1.1rem !important;
    line-height: 1.6 !important;
    font-family: var(--font-family-arabic) !important;
}

/* Title style */
.app-title {
    font-family: var(--font-family-arabic) !important;
    font-size: 2rem !important;
    font-weight: 700 !important;
    color: white !important; /* White text */
    background-color: #3e2b1f !important; /* Dark brown background */
    margin-bottom: 1rem !important;
    margin-top: 1rem !important;
    text-align: center !important;
}

/* Label styles */
.rtl-text label {
    font-family: var(--font-family-arabic) !important;
    font-size: 1.2rem !important;
    font-weight: 600 !important;
    color: #000000 !important; /* Black labels */
    margin-bottom: 0.5rem !important;
}

/* Center the button */
button.primary-button {
    font-family: var(--font-family-arabic) !important;
    background-color: var(--accent-color) !important;
    color: white !important;
    padding: 0.75rem 1.5rem !important;
    border-radius: var(--border-radius) !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    transition: all 0.3s ease !important;
    width: 200px !important; /* Narrower button */
    margin: 0 auto !important; /* Horizontal centering */
    display: block !important; /* Required for margin auto to work */
}

button.primary-button:hover {
    background-color: #3182CE !important;
    transform: translateY(-1px) !important;
}

/* Textbox styles */
.textbox-container {
    background-color: #b45f06 !important;
    padding: 1.5rem !important;
    border-radius: var(--border-radius) !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
    margin-bottom: 1rem !important;
}

/* Loading animation */
.loading {
    animation: pulse 2s infinite;
}

@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.5; }
    100% { opacity: 1; }
}

/* Status style */
.status-text {
    font-family: var(--font-family-arabic) !important;
    text-align: center !important;
    color: var(--secondary-color) !important;
    font-size: 1rem !important;
    margin-top: 1rem !important;
}
"""


with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_classes="container"):
        # App title: "Smart question-and-answer system"
        gr.Markdown(
            "# نظام الأسئلة والأجوبة الذكي",
            elem_classes="app-title rtl-text"
        )

        with gr.Column(elem_classes="textbox-container"):
            # Question input ("السؤال" = "Question")
            input_text = gr.Textbox(
                label="السؤال",
                placeholder="اكتب سؤالك هنا...",
                lines=1,
                elem_classes="rtl-text"
            )

        with gr.Row():
            with gr.Column():
                # Answer output ("الإجابة" = "Answer")
                answer_box = gr.Textbox(
                    label="الإجابة",
                    lines=5,
                    elem_classes="rtl-text textbox-container"
                )

        # Submit button ("إرسال السؤال" = "Send the question")
        submit_btn = gr.Button(
            "إرسال السؤال",
            elem_classes="primary-button",
            variant="primary"
        )

    def stream_response(question):
        # Relay the streamed chunks to the answer box, pacing the updates.
        for chunk_response, _ in process_question(question):
            yield chunk_response
            time.sleep(0.05)

    submit_btn.click(
        fn=stream_response,
        inputs=input_text,
        outputs=answer_box,
        api_name="predict",
        queue=True  # generator outputs require the queue to stream
    )


if __name__ == "__main__":
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=3,
        show_error=True
    )