Commit
·
8fc155a
1
Parent(s):
7455224
Include Mistral explanation generation
Browse files
app.py
CHANGED
|
@@ -42,7 +42,11 @@ from tavily import TavilyClient
|
|
| 42 |
# print("Image model already exists.")
|
| 43 |
|
| 44 |
text_classifier = None
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
|
| 47 |
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
|
| 48 |
GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
|
|
@@ -52,15 +56,15 @@ def get_text_classifier():
|
|
| 52 |
global text_classifier
|
| 53 |
if text_classifier is None:
|
| 54 |
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
text_classifier = pipeline("text-classification", model=
|
| 58 |
return text_classifier
|
| 59 |
|
| 60 |
def _rank_sentences(claim, sentences, top_k=4):
|
| 61 |
if not sentences: return []
|
| 62 |
-
emb_c =
|
| 63 |
-
emb_s =
|
| 64 |
sims = cosine_similarity(emb_c, emb_s)[0]
|
| 65 |
ranked = [s for s, _ in sorted(zip(sentences, sims), key=lambda x: x[1], reverse=True)]
|
| 66 |
return ranked[:top_k]
|
|
@@ -119,34 +123,34 @@ def get_evidence_sentences(claim, k=3):
|
|
| 119 |
|
| 120 |
# --- Classification Function ---
|
| 121 |
def classify_text(claim, text_classifier):
|
| 122 |
-
|
| 123 |
-
evidences = get_evidence_sentences(claim)
|
| 124 |
-
print(evidences)
|
| 125 |
-
if not evidences or "Error" in evidences[0]:
|
| 126 |
-
return f"Prediction: Unknown\n\nTop Evidences:\n{evidences[0]}\n\nExplanation:\nUnable to retrieve reliable evidences."
|
| 127 |
-
|
| 128 |
-
# --- Prepare model input ---
|
| 129 |
evidence_text = " ".join(evidences)
|
| 130 |
-
text = f"Claim: {claim}\nEvidence: {evidence_text}"
|
| 131 |
|
| 132 |
-
#
|
| 133 |
-
|
| 134 |
result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
|
| 135 |
top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
# -------------------
|
|
|
|
| 42 |
# --- Module-level configuration and shared resources ---

# Lazily-built text-classification pipeline; populated by get_text_classifier().
text_classifier = None

# Sentence embedder used to rank evidence sentences against a claim.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# Instruction-tuned LLM used to generate human-readable explanations.
explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
# FEVER-fine-tuned DeBERTa checkpoint used for claim classification.
text_model = "rajyalakshmijampani/fever_finetuned_deberta"

# Wikipedia client for evidence retrieval.
wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
# API keys come from the environment; each is None when the variable is unset.
TAVILY_KEY = os.getenv("TAVILY_API_KEY")
GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
|
|
|
|
| 56 |
def get_text_classifier():
    """Return the shared FEVER text-classification pipeline, building it on first use.

    The transformers import and model download are deferred until the first
    call; subsequent calls reuse the cached module-level ``text_classifier``.
    """
    global text_classifier
    if text_classifier is not None:
        return text_classifier

    # Deferred import keeps module import time low until classification is needed.
    from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

    tok = AutoTokenizer.from_pretrained(text_model)
    clf = AutoModelForSequenceClassification.from_pretrained(text_model)
    text_classifier = pipeline("text-classification", model=clf, tokenizer=tok)
    return text_classifier
|
| 63 |
|
| 64 |
def _rank_sentences(claim, sentences, top_k=4):
    """Return up to *top_k* sentences most similar to *claim* (cosine similarity)."""
    if not sentences:
        return []

    claim_vec = embed_model.encode([claim])
    sent_vecs = embed_model.encode(sentences)
    scores = cosine_similarity(claim_vec, sent_vecs)[0]

    # Stable sort of indices by descending score — same ordering as sorting
    # (sentence, score) pairs on the score.
    order = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
    return [sentences[i] for i in order[:top_k]]
|
|
|
|
| 123 |
|
| 124 |
# --- Classification Function ---
def classify_text(claim, text_classifier):
    """Classify *claim* as REAL/FAKE against retrieved evidence and explain the verdict.

    Pipeline:
      1. Retrieve evidence sentences for the claim.
      2. FEVER classification with the fine-tuned DeBERTa pipeline.
      3. Explanation generation with the Mistral instruct model.

    Parameters:
      claim: the claim text to fact-check.
      text_classifier: an optional pre-built classification pipeline; when None
        the shared lazy classifier from get_text_classifier() is used.
        (The previous version unconditionally shadowed this parameter.)

    Returns a formatted string containing prediction, top evidences, and explanation.
    """
    # Fix: honor a caller-supplied classifier instead of always overwriting it.
    if text_classifier is None:
        text_classifier = get_text_classifier()

    evidences = get_evidence_sentences(claim)
    # Restored guard: retrieval may fail or return an error message as its first item;
    # without this check the error text itself would be classified as evidence.
    if not evidences or "Error" in evidences[0]:
        detail = evidences[0] if evidences else "No evidence found."
        return (f"Prediction: Unknown\n\nTop Evidences:\n{detail}"
                "\n\nExplanation:\nUnable to retrieve reliable evidences.")

    evidence_text = " ".join(evidences)

    # Step 1: FEVER classification.
    text = f"claim: {claim} evidence: {evidence_text}"
    result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
    top_label = max(result, key=lambda x: x["score"])["label"]
    label_str = "REAL" if top_label == "LABEL_0" else "FAKE"

    # Step 2: Mistral explanation generation. Cache the pipeline on the function
    # object so the 7B model is loaded once, not rebuilt on every call.
    explain_pipe = getattr(classify_text, "_explain_pipe", None)
    if explain_pipe is None:
        # Local import mirrors get_text_classifier(); `pipeline` is not imported
        # at module level in this file.
        from transformers import pipeline
        explain_pipe = pipeline(
            "text-generation", model=explain_model, tokenizer=explain_model,
            max_new_tokens=150, temperature=0.5, repetition_penalty=1.2,
        )
        classify_text._explain_pipe = explain_pipe

    evidence_bullets = "\n".join(f"- {e}" for e in evidences)
    prompt = f"""
You are a fact-checking assistant.
Claim: {claim}
Evidence:
{evidence_bullets}
The model predicts this claim is {label_str}.
Write a clear, human-readable explanation of why this classification makes sense, correcting the label if the evidence clearly contradicts it.
"""

    # Keep only the generated continuation after the last "Evidence:" echo.
    expl = explain_pipe(prompt)[0]["generated_text"].split("Evidence:")[-1].strip()

    # Fix: the previous backslash string-continuation leaked source indentation
    # into the user-facing output.
    return (f"Prediction: {label_str}\n\nTop Evidences:\n"
            + evidence_bullets
            + f"\n\nExplanation:\n{expl}")
|
| 154 |
|
| 155 |
|
| 156 |
# -------------------
|