# proofly/model.py
# Commit c7893c0 — Fix: Backend OOM crashes via Vector Cache and worker reduction
# ==========================================
# IMPORTS
# ==========================================
import os
import requests
import faiss
import numpy as np
import urllib.parse
from bs4 import BeautifulSoup
import feedparser
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Import shared config and database layer
from project.config import (
FAISS_FILE, NEWS_API_KEY,
USER_AGENT, SENTENCE_TRANSFORMER_MODEL, NLI_MODEL as NLI_MODEL_NAME
)
from project.database import init_db, clear_db, save_evidence, load_all_evidence
from knowledge_base import KNOWLEDGE_BASE
# ==========================================
# MODEL LOADING
# ==========================================
embed_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
# Use the zero-shot-classification pipeline: run_fact_check() calls this
# model with `candidate_labels=...` and reads `result['labels']` /
# `result['scores']`, which is the zero-shot API. A plain
# "text-classification" pipeline accepts neither and would raise a
# TypeError at inference time.
nli_model = pipeline(
    "zero-shot-classification",
    model=NLI_MODEL_NAME
)
# ==========================================
# RELEVANCE CHECK
# ==========================================
def is_relevant(claim_emb, text, threshold=0.15):
    """Embed *text* and compare it against the claim embedding.

    Since embeddings are normalized, the dot product is the cosine
    similarity. Returns a tuple ``(is_relevant, embedding_as_list)``
    where the first element is True when the similarity meets *threshold*.
    """
    encoded = embed_model.encode([text], normalize_embeddings=True)[0]
    sim = float(np.dot(claim_emb, encoded))
    print(f"[DEBUG] Checking relevance for: '{text[:50]}...' Score: {sim:.4f}")
    return sim >= threshold, encoded.tolist()
# Common English function words that carry no search signal. Hoisted to
# module level so the set is built once instead of on every call.
_STOP_WORDS = frozenset({
    "is", "am", "are", "was", "were", "be", "been", "being",
    "the", "a", "an", "and", "but", "or", "on", "in", "with", "of", "to", "for",
    "he", "she", "it", "they", "we", "i", "you", "that", "this", "these", "those",
    "have", "has", "had", "do", "does", "did", "not", "no", "yes", "from",
})


def get_search_query(claim):
    """Reduce *claim* to a short keyword search query.

    Drops common stop words (matched case-insensitively; original casing
    is preserved in the output) and keeps at most the first five remaining
    words — e.g. "Modi is the president of India" -> "Modi president India".
    """
    keywords = [w for w in claim.split() if w.lower() not in _STOP_WORDS]
    return " ".join(keywords[:5])
# ==========================================
# RSS FETCH
# ==========================================
def fetch_rss(claim_emb):
    """Pull recent headlines from a fixed list of news RSS feeds and save
    the ones relevant to the claim embedding as evidence."""
    print("[RSS] Fetching...")
    feeds = [
        "http://feeds.bbci.co.uk/news/rss.xml",
        "http://rss.cnn.com/rss/edition.rss",
        "https://www.aljazeera.com/xml/rss/all.xml",
        "https://www.theguardian.com/world/rss",
        "https://rss.nytimes.com/services/xml/rss/nyt/World.xml",
        "https://timesofindia.indiatimes.com/rss.cms",
        "https://www.hindustantimes.com/feeds/rss/topstories.rss",
        "https://cfo.economictimes.indiatimes.com/rss",
        "https://www.business-standard.com/rss/",
        "https://www.thehindu.com/news/national/feeder/default.rss",
        "https://indianexpress.com/section/india/feed/",
        "https://feeds.feedburner.com/ndtvnews-top-stories",
    ]
    count = 0
    for url in feeds:
        try:
            feed = feedparser.parse(url)
            print(f"[RSS] Parsed {url}, found {len(feed.entries)} entries")
            # Only the first few entries per feed to keep memory bounded.
            for entry in feed.entries[:5]:
                title = entry.title
                if not title:
                    continue
                keep, emb = is_relevant(claim_emb, title)
                if keep:
                    save_evidence(title, "RSS", embedding=emb)
                    count += 1
        except Exception as e:
            # One broken feed must not abort the remaining ones.
            print(f"[RSS] Error parsing {url}: {e}")
    print(f"[RSS] Saved {count} items.")
# ==========================================
# GDELT FETCH
# ==========================================
def fetch_gdelt(claim, claim_emb):
    """Query the GDELT article API for the claim's keywords and save
    relevant headlines as evidence. Returns the number of items saved
    (0 when the request fails)."""
    print("[GDELT] Fetching...")
    search_query = get_search_query(claim)
    url = "https://api.gdeltproject.org/api/v2/doc/doc"
    params = {
        "query": search_query,
        "mode": "ArtList",
        "format": "json",
        "maxrecords": 5,
    }
    added = 0
    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        articles = response.json().get("articles", [])
        print(f"[GDELT] Found {len(articles)} articles")
        for article in articles:
            headline = article.get("title", "")
            if not headline:
                continue
            keep, emb = is_relevant(claim_emb, headline)
            if keep:
                save_evidence(headline, "GDELT", embedding=emb)
                added += 1
    except Exception as e:
        print("[WARNING] GDELT failed:", e)
    print(f"[GDELT] Saved {added} items.")
    return added
# ==========================================
# NEWS API FETCH
# ==========================================
def fetch_newsapi(claim, claim_emb):
    """Query NewsAPI's "everything" endpoint and save relevant
    title+description pairs as evidence.

    Returns the number of items saved; 0 when the API key is missing,
    the API reports an error, or the request fails."""
    print("[NewsAPI] Fetching...")
    if not NEWS_API_KEY:
        print("[WARNING] NEWS_API_KEY is not set in .env — skipping NewsAPI.")
        return 0
    url = "https://newsapi.org/v2/everything"
    search_query = get_search_query(claim)
    params = {
        "q": search_query,
        "apiKey": NEWS_API_KEY,
        "language": "en",
        "sortBy": "relevancy",
        "pageSize": 5,
    }
    added = 0
    try:
        response = requests.get(url, params=params, timeout=10)
        payload = response.json()
        if response.status_code != 200:
            print(f"[WARNING] NewsAPI Error: {payload.get('message', 'Unknown error')}")
            return 0
        articles = payload.get("articles", [])
        print(f"[NewsAPI] Found {len(articles)} articles")
        for article in articles:
            title = article.get("title", "")
            description = article.get("description", "") or ""
            # Merge title + description; strip stray leading/trailing ". ".
            content = f"{title}. {description}".strip(". ")
            if not content:
                continue
            # Looser threshold than RSS: the merged text is longer and
            # noisier, so exact-title similarity is not expected.
            keep, emb = is_relevant(claim_emb, content, threshold=0.05)
            if keep:
                source_name = article.get('source', {}).get('name', 'Unknown')
                save_evidence(content, f"NewsAPI: {source_name}", embedding=emb)
                added += 1
    except Exception as e:
        print("[WARNING] NewsAPI failed:", e)
    print(f"[NewsAPI] Saved {added} items.")
    return added
# ==========================================
# WIKIPEDIA (REST API)
# ==========================================
def fetch_wikipedia(claim):
    """Search Wikipedia for the claim's keywords and save relevant page
    summaries as evidence.

    Looks up the top 3 search hits, fetches each page's REST summary and
    keeps extracts whose embedding is similar enough to the claim.
    Failures are logged, never raised."""
    print("[Wikipedia] Fetching...")
    search_query = get_search_query(claim)
    try:
        query = urllib.parse.quote(search_query)
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={query}&format=json"
        headers = {"User-Agent": USER_AGENT}
        r = requests.get(url, headers=headers, timeout=10)
        data = r.json()
        results = data.get("query", {}).get("search", [])
        print(f"[Wikipedia] Found {len(results)} search results")
        # Encode the claim ONCE. The original re-encoded it inside the
        # per-result loop even though it never changes — wasted model calls.
        claim_emb = embed_model.encode([claim], normalize_embeddings=True)[0]
        saved = 0
        for result in results[:3]:
            title = result["title"]
            page_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{urllib.parse.quote(title)}"
            r2 = requests.get(page_url, headers=headers, timeout=5)
            if r2.status_code == 200:
                extract = r2.json().get("extract", "")
                # Skip stub/empty summaries.
                if len(extract) > 20:
                    relevant, emb = is_relevant(claim_emb, extract, threshold=0.05)
                    if relevant:
                        save_evidence(extract, f"Wikipedia: {title}", embedding=emb)
                        saved += 1
        print(f"[Wikipedia] Saved {saved} items.")
    except Exception as e:
        print("[WARNING] Wikipedia failed:", e)
# ==========================================
# STATIC KNOWLEDGE BASE
# ==========================================
def fetch_knowledge_base(claim, claim_emb, threshold=0.30):
    """Query the curated static knowledge base using embedding similarity.

    This is called first so timeless facts always get reliable evidence.
    Returns the number of entries saved.

    Improvement: all KB texts are embedded in ONE batched encode() call
    instead of one model invocation per entry — identical vectors, far
    less per-call overhead."""
    print("[KnowledgeBase] Querying static knowledge base...")
    saved = 0
    if KNOWLEDGE_BASE:
        texts = [entry["text"] for entry in KNOWLEDGE_BASE]
        kb_embs = embed_model.encode(texts, normalize_embeddings=True)
        for entry, emb in zip(KNOWLEDGE_BASE, kb_embs):
            sim = float(np.dot(claim_emb, emb))
            if sim >= threshold:
                save_evidence(entry["text"], entry["source"], embedding=emb.tolist())
                saved += 1
    print(f"[KnowledgeBase] Saved {saved} matching entries (threshold={threshold}).")
    return saved
# ==========================================
# WIKIDATA ENTITY SEARCH
# ==========================================
def fetch_wikidata(claim, claim_emb, threshold=0.10):
    """Fetch entity summaries from Wikidata's free public API.
    No API key required. Good for factual entity-level claims.
    Returns the number of items saved (0 on failure)."""
    print("[Wikidata] Fetching...")
    search_query = get_search_query(claim)
    try:
        params = {
            "action": "wbsearchentities",
            "search": search_query,
            "language": "en",
            "format": "json",
            "limit": 5,
            "type": "item",
        }
        response = requests.get(
            "https://www.wikidata.org/w/api.php",
            params=params,
            headers={"User-Agent": USER_AGENT},
            timeout=8,
        )
        response.raise_for_status()
        entities = response.json().get("search", [])
        print(f"[Wikidata] Found {len(entities)} entities")
        saved = 0
        for entity in entities:
            description = entity.get("description", "")
            label = entity.get("label", "")
            # Entities without both a label and a description are useless.
            if not (description and label):
                continue
            text = f"{label}: {description}"
            keep, emb = is_relevant(claim_emb, text, threshold=threshold)
            if keep:
                save_evidence(text, "Wikidata", embedding=emb)
                saved += 1
        print(f"[Wikidata] Saved {saved} items.")
        return saved
    except Exception as e:
        print(f"[WARNING] Wikidata failed: {e}")
        return 0
# ==========================================
# DUCKDUCKGO FALLBACK
# ==========================================
def fetch_duckduckgo(claim, claim_emb):
    """Last-resort evidence source: scrape result titles from the
    DuckDuckGo HTML endpoint when the structured sources came up short."""
    print("[Fallback] DuckDuckGo activated...")
    search_query = get_search_query(claim)
    try:
        encoded_query = urllib.parse.quote(search_query)
        url = f"https://duckduckgo.com/html/?q={encoded_query}"
        # Browser-like UA: the HTML endpoint rejects obvious bots.
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        response = requests.get(url, headers=headers, timeout=10)
        soup = BeautifulSoup(response.text, "html.parser")
        links = soup.find_all("a", class_="result__a", limit=5)
        print(f"[DuckDuckGo] Found {len(links)} results")
        saved = 0
        for link in links:
            snippet = link.get_text()
            # Very short titles carry too little signal to embed usefully.
            if len(snippet) <= 30:
                continue
            keep, emb = is_relevant(claim_emb, snippet, 0.05)
            if keep:
                save_evidence(snippet, "DuckDuckGo", embedding=emb)
                saved += 1
        print(f"[DuckDuckGo] Saved {saved} items")
    except Exception as e:
        print("[WARNING] DuckDuckGo failed:", e)
# ==========================================
# BUILD FAISS
# ==========================================
def build_faiss():
    """Load pre-calculated embeddings from the database and build an
    inner-product FAISS index on disk.

    No re-encoding is performed here — drastically reduces RAM peaks.
    Returns True when an index was written, False when there was nothing
    to index. (The original also collected the evidence texts into a list
    that was never used; that dead code is removed.)"""
    rows = load_all_evidence()
    if not rows:
        return False
    # row[3] holds the stored embedding; skip rows saved without one.
    embeddings_list = [row[3] for row in rows if row[3]]
    if not embeddings_list:
        return False
    embeddings = np.array(embeddings_list).astype('float32')
    # IndexFlatIP = exact inner-product search; with normalized vectors
    # this is cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, FAISS_FILE)
    return True
# ==========================================
# MAIN PIPELINE (CLI / standalone use)
# ==========================================
def run_fact_check(claim):
    """End-to-end CLI fact-check for a single *claim*.

    Clears the evidence DB, gathers evidence from RSS/GDELT/NewsAPI/
    Wikipedia (with a DuckDuckGo fallback when coverage is thin or the
    best match is weak), builds a FAISS index over the stored embeddings,
    then prints the top-5 evidence snippets and their NLI verdicts.
    """
    print("\n[FACT CHECK]", claim)
    init_db()
    clear_db()
    claim_emb = embed_model.encode([claim], normalize_embeddings=True)
    # Fetch from all sources (now includes NewsAPI, consistent with api_wrapper)
    fetch_rss(claim_emb)
    gdelt_count = fetch_gdelt(claim, claim_emb)
    newsapi_count = fetch_newsapi(claim, claim_emb)
    fetch_wikipedia(claim)
    from project.database import get_total_evidence_count
    total_count = get_total_evidence_count()
    # Fallback triggers when the structured APIs produced nothing, the DB
    # is nearly empty, or the single best match is only weakly similar.
    activate_fallback = (gdelt_count + newsapi_count) == 0 or total_count < 3
    if build_faiss():
        if os.path.exists(FAISS_FILE):
            index = faiss.read_index(FAISS_FILE)
            D, _ = index.search(claim_emb, 1)
            if len(D) > 0 and len(D[0]) > 0:
                similarity = D[0][0]
                if similarity < 0.50:
                    activate_fallback = True
    if activate_fallback:
        fetch_duckduckgo(claim, claim_emb)
        build_faiss()
    if not os.path.exists(FAISS_FILE):
        print("[ERROR] No evidence found.")
        return
    index = faiss.read_index(FAISS_FILE)
    D, I = index.search(claim_emb, 5)
    rows = load_all_evidence()
    # FIX: FAISS pads missing neighbours with -1 when the index holds fewer
    # than k vectors. The original `idx < len(rows)` check let -1 through,
    # so rows[-1] (the last row) was printed as spurious evidence.
    hits = [idx for idx in I[0] if 0 <= idx < len(rows)]
    print("\n[EVIDENCE]")
    for idx in hits:
        print("-", rows[idx][1][:200])
    print("\n[NLI RESULTS]")
    for idx in hits:
        evidence_text = rows[idx][1]
        candidate_labels = [
            f"Supports the claim: {claim}",
            f"Contradicts the claim: {claim}",
            f"Is unrelated to the claim: {claim}"
        ]
        # NOTE(review): `candidate_labels=` and `result['labels']` are the
        # zero-shot-classification calling convention; nli_model must be a
        # zero-shot pipeline for this call to succeed — confirm at load site.
        result = nli_model(evidence_text, candidate_labels=candidate_labels)
        if result and 'labels' in result:
            top_label = result['labels'][0]
            top_score = result['scores'][0]
            print(f"[{top_label}] (Score: {top_score:.2f})")
        else:
            print(result)
# ==========================================
# RUN
# ==========================================
if __name__ == "__main__":
    # Interactive CLI entry point: prompt for a claim and fact-check it.
    user_claim = input("Enter claim: ").strip()
    if not user_claim:
        print("Claim cannot be empty.")
    else:
        run_fact_check(user_claim)