Spaces:
Running
Running
| import os | |
| import json | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.decomposition import TruncatedSVD | |
| from sklearn.pipeline import Pipeline | |
# Location of the training claims, resolved relative to this module's directory.
DATA_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'data', 'raw', 'train.json')

# Number of SVD components, i.e. the width of every embedding produced by _embed().
# Presumably must match the dimension of the pgvector `embedding` column — TODO confirm schema.
RAG_DIMS = 512

# Module-level state, (re)assigned by load_rag_data().
_pipeline: Pipeline | None = None  # fitted TF-IDF -> SVD model; None until fitted
_pgvector_ready = False  # True once the rag_claims index has been verified/populated

# --- helpers -------------------------------------------------------------------
def _database_url() -> str:
    """Return the PostgreSQL connection URL.

    Uses the ``DATABASE_URL`` environment variable when it is set; otherwise
    falls back to a hard-coded URL with embedded credentials (presumably the
    local docker-compose default — do not rely on it outside development).
    """
    fallback = 'postgresql://mlflow:mlflow123@localhost:5432/mlflow_db'
    return os.environ.get('DATABASE_URL', fallback)
def _connect():
    """Open a new psycopg2 connection with the pgvector type adapter registered.

    The caller owns the returned connection and is responsible for closing it.
    If registering the vector type raises, the freshly opened connection is
    closed before the exception is re-raised, so it does not leak.

    Returns:
        An open psycopg2 connection to ``_database_url()``.
    """
    # Imported lazily — presumably so this module can be imported without the
    # database drivers installed.
    import psycopg2
    from pgvector.psycopg2 import register_vector

    conn = psycopg2.connect(_database_url())
    try:
        register_vector(conn)
    except Exception:
        conn.close()
        raise
    return conn
def _fit_pipeline(texts: list[str]) -> Pipeline:
    """Fit and return an LSA embedding model on *texts*.

    The model is a TF-IDF vectorizer (up to 5000 uni/bi-gram features, English
    stop words removed) followed by a TruncatedSVD reduction to ``RAG_DIMS``
    components. ``random_state=42`` makes repeated fits on the same input reproducible.
    """
    steps = [
        ('tfidf', TfidfVectorizer(ngram_range=(1, 2), stop_words='english', max_features=5000)),
        ('svd', TruncatedSVD(random_state=42, n_components=RAG_DIMS)),
    ]
    model = Pipeline(steps)
    model.fit(texts)
    return model
def _embed(pipe: Pipeline, texts: list[str]) -> np.ndarray:
    """Embed *texts* with the fitted *pipe* and L2-normalise each row.

    Rows whose norm is zero are left as all-zero vectors instead of being
    divided by zero. The returned array is cast to float32.
    """
    raw = pipe.transform(texts)
    lengths = np.linalg.norm(raw, axis=1, keepdims=True)
    # Replace zero norms by 1 so zero rows stay zero rather than becoming NaN.
    safe_lengths = np.where(lengths == 0, 1.0, lengths)
    unit_rows = raw / safe_lengths
    return unit_rows.astype(np.float32)
# --- startup -------------------------------------------------------------------


def _field_text(item: dict, key: str) -> str:
    """Return ``item[key]`` as text, mapping a missing key or JSON null to ''."""
    value = item.get(key)
    return '' if value is None else str(value)


def load_rag_data(_main_vectorizer=None) -> None:
    """Fit the LSA embedding pipeline and make sure pgvector holds the claim index.

    Reads the claims from ``DATA_PATH`` and (re)fits the module-level
    ``_pipeline`` on their statements. If the ``rag_claims`` table is empty, it
    embeds every claim and bulk-inserts it. On success ``_pgvector_ready``
    becomes True. Every failure (missing/unreadable file, malformed items,
    fitting error, database error) is reported with ``print`` and leaves
    retrieval disabled instead of raising.

    Args:
        _main_vectorizer: Ignored by this implementation (presumably kept so
            existing call sites keep working).
    """
    global _pipeline, _pgvector_ready
    if not os.path.exists(DATA_PATH):
        print(f"RAG: {DATA_PATH} not found β retrieval disabled.")
        return
    try:
        with open(DATA_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"RAG: failed to load data β {e}")
        return

    # Extraction and fitting can fail on malformed or empty data (e.g. non-dict
    # items, an empty vocabulary). Treat that like the other failures: disable
    # retrieval rather than crashing the caller at startup.
    try:
        statements = [_field_text(item, 'statement') for item in data]
        labels = [_field_text(item, 'label') for item in data]
        reasons = [_field_text(item, 'reason') for item in data]
        # Always (re)fit the pipeline. Rows already stored in pgvector were embedded
        # by an earlier fit; with a fixed random_state this presumably reproduces
        # them only if train.json is unchanged — TODO reindex when the data changes.
        print("RAG: fitting LSA pipeline (TF-IDF 5 000 β SVD 512)β¦")
        _pipeline = _fit_pipeline(statements)
    except Exception as e:
        print(f"RAG: failed to build embedding pipeline - {e}")
        _pipeline = None
        _pgvector_ready = False
        return

    db_url = _database_url()
    print(f"RAG: connecting to {db_url.split('@')[-1]}β¦")  # hide credentials
    conn = None
    try:
        conn = _connect()
        with conn.cursor() as cur:
            cur.execute("SELECT COUNT(*) FROM rag_claims;")
            count = cur.fetchone()[0]
            if count == 0:
                print(f"RAG: inserting {len(statements)} claims into pgvectorβ¦")
                vectors = _embed(_pipeline, statements)
                from psycopg2.extras import execute_values
                # Reasons are truncated to 500 characters before storage.
                rows = [
                    (statement, label, reason[:500], vector)
                    for statement, label, reason, vector
                    in zip(statements, labels, reasons, vectors)
                ]
                execute_values(
                    cur,
                    "INSERT INTO rag_claims (statement, label, reason, embedding) VALUES %s",
                    rows,
                    template="(%s, %s, %s, %s)",
                )
                conn.commit()
                print(f"RAG: {len(rows)} claims indexed in pgvector β")
            else:
                print(f"RAG: {count} claims already in pgvector β skipping insert.")
        _pgvector_ready = True
        print("RAG: ready β")
    except Exception as e:
        print(f"RAG ERROR: {e}")
        print("RAG: retrieval disabled β run 'docker compose down -v && docker compose up --build' to reset.")
        _pgvector_ready = False
    finally:
        # Always release the connection; an uncommitted insert is discarded on close.
        if conn is not None:
            conn.close()
# --- retrieval -----------------------------------------------------------------


def retrieve_similar(claim: str, predicted_label: str, top_k: int = 3) -> list:
    """Return up to *top_k* stored claims with the same label, nearest to *claim*.

    The claim is embedded with the fitted ``_pipeline``. ``rag_claims`` rows
    whose ``label`` equals *predicted_label* are then ranked by pgvector's
    cosine-distance operator ``<=>``.

    Returns:
        A list of dicts with keys ``statement``, ``label``, ``reason`` and
        ``similarity`` (``1 - cosine distance``, as float). It is empty when the
        index is not ready or when any error occurs (the error is printed).
    """
    if not _pgvector_ready or _pipeline is None:
        return []
    conn = None
    try:
        query_vec = _embed(_pipeline, [claim])[0]
        conn = _connect()
        with conn.cursor() as cur:
            cur.execute(
                """
SELECT statement, label, reason,
1 - (embedding <=> %s) AS similarity
FROM rag_claims
WHERE label = %s
ORDER BY embedding <=> %s
LIMIT %s
""",
                (query_vec, predicted_label, query_vec, top_k),
            )
            rows = cur.fetchall()
        return [
            {
                'statement': row[0],
                'label': row[1],
                'reason': row[2],
                'similarity': float(row[3]),
            }
            for row in rows
        ]
    except Exception as e:
        print(f"RAG: retrieval error β {e}")
        return []
    finally:
        # Previously the connection leaked whenever the query or conversion raised.
        if conn is not None:
            conn.close()
# --- generation (NVIDIA NIM) ---------------------------------------------------
# Endpoint and model passed to the OpenAI client in _generate_justification().
_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
_NIM_MODEL = "meta/llama-3.1-8b-instruct"
| def _generate_justification(claim: str, predicted_label: str, similar: list) -> str | None: | |
| if not similar: | |
| return None | |
| nim_key = os.environ.get('NVIDIA_API_KEY') | |
| if not nim_key: | |
| print("RAG: NVIDIA_API_KEY not set β justification disabled.") | |
| return None | |
| try: | |
| from openai import OpenAI | |
| client = OpenAI(base_url=_NIM_BASE_URL, api_key=nim_key) | |
| except ImportError: | |
| return None | |
| label_fr = { | |
| 'true': 'VRAI', | |
| 'mostly-true': 'MAJORITAIREMENT VRAI', | |
| 'half-true': 'PARTIELLEMENT VRAI', | |
| 'barely-true': 'Γ PEINE VRAI', | |
| 'false': 'FAUX', | |
| 'pants-fire': 'TOTALEMENT FAUX', | |
| }.get(predicted_label, predicted_label.upper()) | |
| examples = [] | |
| for i, s in enumerate(similar[:3], 1): | |
| examples.append( | |
| f"Exemple {i} (similaritΓ©={s['similarity']:.2f}) :\n" | |
| f" DΓ©claration : {s['statement']}\n" | |
| f" Raison : {(s['reason'] or '')[:250]}" | |
| ) | |
| prompt = ( | |
| f"Tu es un assistant de vΓ©rification des faits.\n" | |
| f"DΓ©claration : Β« {claim} Β»\n" | |
| f"Classification : {label_fr} ({predicted_label})\n\n" | |
| f"Exemples similaires (mΓͺme classification) dans la base LIAR :\n\n" | |
| + '\n\n'.join(examples) + | |
| f"\n\nEn 2-3 phrases concises en franΓ§ais, justifie pourquoi cette dΓ©claration est" | |
| f" classΓ©e Β« {predicted_label} Β» en t'appuyant sur les similitudes avec ces exemples." | |
| " RΓ©ponds directement sans titre ni introduction." | |
| ) | |
| try: | |
| resp = client.chat.completions.create( | |
| model=_NIM_MODEL, | |
| messages=[{"role": "user", "content": prompt}], | |
| max_tokens=350, | |
| temperature=0.3, | |
| ) | |
| return resp.choices[0].message.content.strip() | |
| except Exception as exc: | |
| print(f"RAG NIM generation error: {exc}") | |
| return None | |
# --- public API ----------------------------------------------------------------


def get_rag_result(claim: str, predicted_label: str, _vectorizer=None) -> tuple[list, str | None]:
    """Retrieve same-label evidence for *claim* and an optional LLM justification.

    Args:
        claim: The claim text to explain.
        predicted_label: The label predicted for *claim*; retrieval is restricted to it.
        _vectorizer: Ignored (presumably retained for call-site compatibility).

    Returns:
        A pair ``(evidence, justification)``. *evidence* is a list of dicts with
        ``statement``, ``label``, ``similarity`` (rounded to 3 decimals) and
        ``reason`` (truncated to 300 characters; None when empty).
        *justification* is the generated text or None.
    """
    similar = retrieve_similar(claim, predicted_label)
    justification = _generate_justification(claim, predicted_label, similar)
    evidence = []
    for match in similar:
        reason_text = (match['reason'] or '')[:300]
        evidence.append({
            'statement': match['statement'],
            'label': match['label'],
            'similarity': round(match['similarity'], 3),
            'reason': reason_text or None,
        })
    return evidence, justification