import os
import re

import gradio as gr
import requests
import torch
import wikipedia
import wikipediaapi
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tavily import TavilyClient
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# API keys / tokens (set these as environment secrets)
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Models and clients
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
text_model = "rajyalakshmijampani/fever_finetuned_deberta"
inf_client = InferenceClient(token=HF_TOKEN)
wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
tavily = TavilyClient(api_key=TAVILY_KEY)

# Lazily-loaded text classifier (see get_text_classifier below)
text_classifier = None
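
# Pipeline overview:
#   1. Gather evidence sentences for a claim (Google Fact Check API first,
#      then Tavily web search, then Wikipedia as fallbacks).
#   2. Classify claim + evidence with the FEVER-finetuned DeBERTa model.
#   3. Ask Mistral-7B-Instruct (via the HF Inference API) for a short
#      explanation of the verdict.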


def get_text_classifier():
    """Load the FEVER-finetuned classifier on first use and cache it."""
    global text_classifier
    if text_classifier is None:
        tokenizer = AutoTokenizer.from_pretrained(text_model)
        seq_clf = AutoModelForSequenceClassification.from_pretrained(text_model)
        text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
    return text_classifier


def _rank_sentences(claim, sentences, top_k=4):
    if not sentences:
        return []
    emb_c = embed_model.encode([claim])
    emb_s = embed_model.encode(sentences)
    sims = cosine_similarity(emb_c, emb_s)[0]
    ranked = [s for s, _ in sorted(zip(sentences, sims), key=lambda x: x[1], reverse=True)]
    return ranked[:top_k]
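
# Example (illustrative): _rank_sentences("Paris is in France",
#     ["Paris is the capital of France.", "Bananas are yellow."])
# returns both sentences with the Paris one first, since its MiniLM
# embedding is closer to the claim's.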


def _split_sentences(text):
    # Naive sentence splitter; keeps sentences between 26 and 249 characters.
    return [s.strip() for s in re.split(r'(?<=[.!?]) +', text) if 25 < len(s) < 250]


def _from_google(claim):
    """Query the Google Fact Check Tools API for published claim reviews."""
    if not GOOGLE_KEY:
        return []
    url = "https://factchecktools.googleapis.com/v1alpha1/claims:search"
    try:
        r = requests.get(url, params={"query": claim, "key": GOOGLE_KEY, "pageSize": 2}, timeout=10).json()
    except Exception:
        return []
    evid = []
    for c in r.get("claims", []):
        rev = c.get("claimReview", [])
        if rev:
            rating = rev[0].get("textualRating", "")
            site = rev[0].get("publisher", {}).get("name", "")
            evid.append(f"{site} rated this claim as {rating}.")
    return evid[:3]


def _from_tavily(claim):
    if not TAVILY_KEY:
        return []
    try:
        results = tavily.search(claim).get("results", [])
        sents = []
        for r in results:
            for s in _split_sentences(r.get("content", "")):
                if not any(x in s.lower() for x in ["video game", "film", "fiction"]):
                    sents.append(s)
        return _rank_sentences(claim, sents, 4)
    except Exception:
        return []


def _from_wiki(claim):
    try:
        titles = wikipedia.search(claim, results=3)
        sents = []
        for t in titles:
            page = wiki.page(t)
            if not page.exists():
                continue
            for s in _split_sentences(page.text[:2000]):
                if not any(x in s.lower() for x in ["video game", "fiction", "film"]):
                    sents.append(s)
        return _rank_sentences(claim, sents, 4)
    except Exception:
        return []


def get_evidence_sentences(claim, k=3):
    """Collect up to k evidence sentences: fact-check ratings first, then web search, then Wikipedia."""
    evid = _from_google(claim)
    if len(evid) >= k:
        return evid[:k]
    evid += _from_tavily(claim)
    if len(evid) >= k:
        return evid[:k]
    evid += _from_wiki(claim)
    return (evid or ["Error: No relevant evidence found."])[:k]
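
# Example (illustrative): get_evidence_sentences("The Earth is flat") tries the
# Google Fact Check API first; if it yields fewer than k sentences, Tavily web
# results and then Wikipedia sentences are appended until k are collected.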


# --- Classification Function ---
def classify_text(claim):
    classifier = get_text_classifier()
    evidences = get_evidence_sentences(claim)
    evidence_text = " ".join(evidences)

    # Step 1: FEVER classification
    text = f"claim: {claim} evidence: {evidence_text}"
    result = classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
    top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
    label_str = "REAL" if top_label == "LABEL_0" else "FAKE"

    # Step 2: Mistral explanation generation
    prompt = f"""
You are a reliable fact-checking assistant.
Claim: {claim}
Evidence: {chr(10).join(f"- {e}" for e in evidences)}
The model predicts this claim is {label_str}.
Write a short, clear explanation of why this classification makes sense.
If the evidence clearly contradicts the label, correct the label in your explanation.
"""
    messages = [
        {"role": "system", "content": "You are a reliable fact-checking assistant."},
        {"role": "user", "content": prompt},
    ]
    completion = inf_client.chat_completion(model=explain_model, messages=messages, max_tokens=256, temperature=0.3)
    explanation = completion.choices[0].message.content.strip()
    return f"Prediction: {label_str}\n\nExplanation:\n{explanation}"


# -------------------
# Image classification
# -------------------
image_model = None  # NOTE: no image model is loaded in this file; assign one here to enable this tab


def classify_image(img):
    if img is None:
        return "Please upload an image."
    if image_model is None:
        return "Error: no image model is loaded."
    # Convert the PIL image to a normalized (1, 3, 224, 224) float tensor.
    img_tensor = torch.tensor(
        [list(img.resize((224, 224)).getdata())], dtype=torch.float32
    ).view(1, 224, 224, 3).permute(0, 3, 1, 2) / 255.0
    with torch.no_grad():
        output = image_model(img_tensor)
    preds = torch.softmax(output, dim=1)
    label = torch.argmax(preds).item()
    label_str = "REAL" if label == 1 else "FAKE"
    return f"Prediction: {label_str}\n\nExplanation: The image model classifies this as {label_str.lower()} based on learned patterns."


# -------------------
# UI Layout (Gradio)
# -------------------
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Misinformation Detector")
    with gr.Tab("Text Detector"):
        claim = gr.Textbox(label="Enter Claim")
        text_button = gr.Button("Classify Claim")
        text_output = gr.Textbox(label="Model Output", lines=8)
        text_button.click(classify_text, inputs=claim, outputs=text_output)
    with gr.Tab("Image Detector"):
        img_input = gr.Image(type="pil", label="Upload Image")
        img_output = gr.Textbox(label="Model Output", lines=6)
        img_button = gr.Button("Classify Image")
        img_button.click(classify_image, inputs=img_input, outputs=img_output)

demo.launch()