Commit 26212b8
Parent(s): 6ff8947

hf interface for mistral

Browse files: app.py (+20 -51) · requirements.txt (+2 -1)
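In short: the commented-out Google Drive model downloads and the local Mistral text-generation pipeline are removed; explanations are now generated remotely through huggingface_hub's InferenceClient, which adds an HF_TOKEN secret and one new dependency.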
app.py
CHANGED

@@ -2,7 +2,7 @@ import os
 import torch
 import requests
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 from PIL import Image
 from io import BytesIO
 import wikipedia

@@ -11,46 +11,19 @@ import re
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from tavily import TavilyClient
-
-# # Download from Google Drive
-# def download_from_drive(drive_url, dest_path):
-#     gdown.download(drive_url, dest_path, quiet=False)
-#     return
-
-# # Download models
-# TEXT_MODEL_ZIP_URL = "https://drive.google.com/uc?export=download&id=1Sf2DoVaYBqBcdvonf6GJpo_bLWATSgeq"
-# IMAGE_MODEL_URL = "https://drive.google.com/uc?export=download&id=19xRLjNtGWty9loc0_6LPjIYOl-EIf2bm"
-
-# os.makedirs("models", exist_ok=True)
-
-# # Text model
-# if not os.path.exists("models/text_model"):
-#     print("Downloading and extracting text model...")
-#     zip_path = "models/text_model.zip"
-#     download_from_drive(TEXT_MODEL_ZIP_URL, zip_path)
-#     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-#         zip_ref.extractall("models/text_model")
-# else:
-#     print("Text model already exists.")
-
-# # Image model
-# if not os.path.exists("models/image_model.pth"):
-#     print("Downloading image model...")
-#     pth_path = "models/image_model.pth"
-#     download_from_drive(IMAGE_MODEL_URL, pth_path)
-# else:
-#     print("Image model already exists.")
+from huggingface_hub import InferenceClient

 text_classifier = None
-text_explainer = None
+TAVILY_KEY = os.getenv("TAVILY_API_KEY")
+GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
+HF_TOKEN = os.getenv("HF_TOKEN")

 embed_model = SentenceTransformer("all-MiniLM-L6-v2")
 explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
 text_model = "rajyalakshmijampani/fever_finetuned_deberta"

+inf_client = InferenceClient(token=HF_TOKEN)
 wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
-TAVILY_KEY = os.getenv("TAVILY_API_KEY")
-GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
 tavily = TavilyClient(api_key=TAVILY_KEY)

 def get_text_classifier():

@@ -61,14 +34,6 @@ def get_text_classifier():
     text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
     return text_classifier

-def get_text_explainer():
-    global text_explainer
-    if text_explainer is None:
-        tokenizer = AutoTokenizer.from_pretrained(explain_model)
-        clm = AutoModelForCausalLM.from_pretrained(explain_model)
-        text_explainer = pipeline("text-generation", model=clm, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, repetition_penalty=1.2)
-    return text_explainer
-
 def _rank_sentences(claim, sentences, top_k=4):
     if not sentences: return []
     emb_c = embed_model.encode([claim])

@@ -130,30 +95,34 @@ def get_evidence_sentences(claim, k=3):
     return (evid or ["Error: No relevant evidence found."])[:k]

 # --- Classification Function ---
-def classify_text(claim):
-    text_classifier = get_text_classifier()
-    text_explainer = get_text_explainer()
+def classify_text(claim):
+    classifier = get_text_classifier()
     evidences = get_evidence_sentences(claim)
     evidence_text = " ".join(evidences)

     # Step 1: FEVER classification
     text = f"claim: {claim} evidence: {evidence_text}"
-    result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
+    result = classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
     top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
     label_str = "REAL" if top_label == "LABEL_0" else "FAKE"

     # Step 2: Mistral explanation generation
     prompt = f"""
-You are a fact-checking assistant.
+You are a reliable fact-checking assistant.
 Claim: {claim}
 Evidence: {chr(10).join(f"- {e}" for e in evidences)}
 The model predicts this claim is {label_str}.
-Write a short explanation of why this classification makes sense.
+Write a short, clear explanation of why this classification makes sense.
+If the evidence clearly contradicts the label, correct the label in your explanation.
 """
-    generation = text_explainer(prompt)
-    explanation = generation[0]["generated_text"].strip()
-
-    return f"Prediction: {label_str}\n\nExplanation:\n{explanation}"
+    messages = [
+        {"role": "system", "content": "You are a reliable fact-checking assistant."},
+        {"role": "user", "content": prompt},
+    ]
+    completion = inf_client.chat_completion(model=explain_model, messages=messages, max_tokens=256, temperature=0.3)
+    explanation = completion.choices[0].message.content.strip()
+
+    return f"Prediction: {label_str}\n\nExplanation:\n{explanation}"

 # -------------------
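For reference, here is the new remote-generation path from this commit in isolation, as a minimal sketch. It assumes a valid HF_TOKEN with access to the hosted mistralai/Mistral-7B-Instruct-v0.2 endpoint; the helper name `explain_with_mistral` and the example claim are hypothetical, while `explain_model`, `inf_client`, and the call parameters mirror the diff above.

```python
import os
from huggingface_hub import InferenceClient

# Mirrors the diff: hosted Mistral model plus a client authenticated via HF_TOKEN.
explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
inf_client = InferenceClient(token=os.getenv("HF_TOKEN"))

def explain_with_mistral(claim, label_str, evidences):
    # Build the same style of prompt that classify_text() assembles above.
    evidence_block = "\n".join(f"- {e}" for e in evidences)
    prompt = (
        f"Claim: {claim}\n"
        f"Evidence:\n{evidence_block}\n"
        f"The model predicts this claim is {label_str}.\n"
        "Write a short, clear explanation of why this classification makes sense."
    )
    messages = [
        {"role": "system", "content": "You are a reliable fact-checking assistant."},
        {"role": "user", "content": prompt},
    ]
    # chat_completion returns an OpenAI-style response object.
    completion = inf_client.chat_completion(
        model=explain_model, messages=messages, max_tokens=256, temperature=0.3
    )
    return completion.choices[0].message.content.strip()

if __name__ == "__main__":
    print(explain_with_mistral(
        "The Eiffel Tower is in Berlin.",
        "FAKE",
        ["The Eiffel Tower is a wrought-iron lattice tower in Paris, France."],
    ))
```

Compared with the removed `get_text_explainer()`, this trades local GPU memory for a network call: the 7B model no longer has to fit on the Space's hardware.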
requirements.txt
CHANGED

@@ -7,4 +7,5 @@ wikipedia-api
 wikipedia
 sentence-transformers
 scikit-learn
-tavily-python
+tavily-python
+huggingface-hub
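The new `huggingface-hub` entry provides `InferenceClient`. A quick pre-flight check (a hypothetical snippet, not part of the repo) to run before launching the Space:

```python
# Confirms the new dependency is importable and that the token the app
# reads via os.getenv("HF_TOKEN") is actually set in the environment.
import os
import huggingface_hub

print("huggingface_hub version:", huggingface_hub.__version__)
assert os.getenv("HF_TOKEN"), "HF_TOKEN is missing; add it as a Space secret."
```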