import os
import re
from io import BytesIO

import gradio as gr
import requests
import torch
import wikipedia
import wikipediaapi
from huggingface_hub import InferenceClient
from PIL import Image
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tavily import TavilyClient
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# Lazily-initialised claim classification pipeline (built in get_text_classifier).
text_classifier = None

# API credentials are read from the environment.
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Sentence embeddings for evidence ranking, an instruction-tuned LLM for
# explanations, and a DeBERTa model fine-tuned on FEVER for the verdict.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
text_model = "rajyalakshmijampani/fever_finetuned_deberta"

# Shared clients for Hugging Face inference, Wikipedia, and Tavily web search.
inf_client = InferenceClient(token=HF_TOKEN)
wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
tavily = TavilyClient(api_key=TAVILY_KEY)


def get_text_classifier():
    """Load and cache the claim classification pipeline on first use."""
    global text_classifier
    if text_classifier is None:
        tokenizer = AutoTokenizer.from_pretrained(text_model)
        seq_clf = AutoModelForSequenceClassification.from_pretrained(text_model)
        text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
    return text_classifier


def _rank_sentences(claim, sentences, top_k=4):
    """Rank candidate sentences by cosine similarity to the claim and keep the top_k."""
    if not sentences:
        return []
    emb_c = embed_model.encode([claim])
    emb_s = embed_model.encode(sentences)
    sims = cosine_similarity(emb_c, emb_s)[0]
    ranked = [s for s, _ in sorted(zip(sentences, sims), key=lambda x: x[1], reverse=True)]
    return ranked[:top_k]


def _split_sentences(text):
    """Split text on sentence boundaries, keeping only sentences of a useful length."""
    return [s.strip() for s in re.split(r'(?<=[.!?]) +', text) if 25 < len(s) < 250]


def _from_google(claim):
    """Query the Google Fact Check Tools API and summarise up to three claim reviews."""
    if not GOOGLE_KEY:
        return []
    url = "https://factchecktools.googleapis.com/v1alpha1/claims:search"
    try:
        r = requests.get(url, params={"query": claim, "key": GOOGLE_KEY, "pageSize": 2}, timeout=10).json()
    except requests.RequestException:
        return []
    claims = r.get("claims", [])
    evid = []
    for c in claims:
        rev = c.get("claimReview", [])
        if rev:
            rating = rev[0].get("textualRating", "")
            site = rev[0].get("publisher", {}).get("name", "")
            evid.append(f"{site} rated this claim as {rating}.")
    return evid[:3]


def _from_tavily(claim):
    """Search the web with Tavily and rank the retrieved sentences against the claim."""
    if not TAVILY_KEY:
        return []
    try:
        results = tavily.search(claim).get("results", [])
        sents = []
        for r in results:
            for s in _split_sentences(r.get("content", "")):
                # Skip sentences about fictional works, which tend to mislead the classifier.
                if not any(x in s.lower() for x in ["video game", "film", "fiction"]):
                    sents.append(s)
        return _rank_sentences(claim, sents, 4)
    except Exception:
        return []


def _from_wiki(claim):
    """Fall back to Wikipedia: rank sentences from the opening text of matching pages."""
    try:
        titles = wikipedia.search(claim, results=3)
        sents = []
        for t in titles:
            page = wiki.page(t)
            if not page.exists():
                continue
            for s in _split_sentences(page.text[:2000]):
                if not any(x in s.lower() for x in ["video game", "fiction", "film"]):
                    sents.append(s)
        return _rank_sentences(claim, sents, 4)
    except Exception:
        return []


def get_evidence_sentences(claim, k=3):
    """Collect up to k evidence sentences: fact-check reviews first, then web search, then Wikipedia."""
    evid = _from_google(claim)
    if len(evid) >= k:
        return evid[:k]
    evid += _from_tavily(claim)
    if len(evid) >= k:
        return evid[:k]
    evid += _from_wiki(claim)
    return (evid or ["Error: No relevant evidence found."])[:k]


def classify_text(claim):
    """Classify a claim as REAL or FAKE and ask the LLM to explain the verdict."""
    classifier = get_text_classifier()
    evidences = get_evidence_sentences(claim)
    evidence_text = " ".join(evidences)

    # The FEVER-style classifier expects the claim and its evidence in a single string.
    text = f"claim: {claim} evidence: {evidence_text}"
    result = classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
    top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
    label_str = "REAL" if top_label == "LABEL_0" else "FAKE"

    # Ask the instruction-tuned model to justify (or, if the evidence contradicts it, correct) the verdict.
    prompt = f"""
You are a reliable fact-checking assistant.
Claim: {claim}
Evidence: {chr(10).join(f"- {e}" for e in evidences)}
The model predicts this claim is {label_str}.
Write a short, clear explanation of why this classification makes sense.
If the evidence clearly contradicts the label, correct the label in your explanation.
"""
    messages = [
        {"role": "system", "content": "You are a reliable fact-checking assistant."},
        {"role": "user", "content": prompt},
    ]
    completion = inf_client.chat_completion(model=explain_model, messages=messages, max_tokens=256, temperature=0.3)
    explanation = completion.choices[0].message.content.strip()

    return f"Prediction: {label_str}\n\nExplanation:\n{explanation}"


def classify_image(img):
    """Classify an uploaded image as REAL or FAKE with the image model."""
    if img is None:
        return "Please upload an image."

    # Convert the PIL image to a normalised (1, 3, 224, 224) float tensor.
    img = img.convert("RGB").resize((224, 224))
    img_tensor = torch.tensor(
        [list(img.getdata())], dtype=torch.float32
    ).view(1, 224, 224, 3).permute(0, 3, 1, 2) / 255.0

    # `image_model` is assumed to be a torch image classifier loaded elsewhere in the app.
    with torch.no_grad():
        output = image_model(img_tensor)
        preds = torch.softmax(output, dim=1)
        label = torch.argmax(preds).item()

    label_str = "REAL" if label == 1 else "FAKE"
    return f"Prediction: {label_str}\n\nExplanation: The image model classifies this as {label_str.lower()} based on learned patterns."


with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Misinformation Detector")

    with gr.Tab("Text Detector"):
        claim = gr.Textbox(label="Enter Claim")
        text_button = gr.Button("Classify Claim")
        text_output = gr.Textbox(label="Model Output", lines=8)
        text_button.click(classify_text, inputs=claim, outputs=text_output)

    with gr.Tab("Image Detector"):
        img_input = gr.Image(type="pil", label="Upload Image")
        img_output = gr.Textbox(label="Model Output", lines=6)
        img_button = gr.Button("Classify Image")
        img_button.click(classify_image, inputs=img_input, outputs=img_output)

demo.launch()