rajyalakshmijampani committed on
Commit
8fc155a
·
1 Parent(s): 7455224

include mistral explanation

Browse files
Files changed (1) hide show
  1. app.py +34 -30
app.py CHANGED
@@ -42,7 +42,11 @@ from tavily import TavilyClient
42
  # print("Image model already exists.")
43
 
44
  text_classifier = None
45
- MODEL = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
46
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
47
  TAVILY_KEY = os.getenv("TAVILY_API_KEY")
48
  GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
@@ -52,15 +56,15 @@ def get_text_classifier():
52
  global text_classifier
53
  if text_classifier is None:
54
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
55
- text_tokenizer = AutoTokenizer.from_pretrained("rajyalakshmijampani/fever_finetuned_deberta")
56
- text_model = AutoModelForSequenceClassification.from_pretrained("rajyalakshmijampani/fever_finetuned_deberta")
57
- text_classifier = pipeline("text-classification", model=text_model, tokenizer=text_tokenizer)
58
  return text_classifier
59
 
60
  def _rank_sentences(claim, sentences, top_k=4):
61
  if not sentences: return []
62
- emb_c = MODEL.encode([claim])
63
- emb_s = MODEL.encode(sentences)
64
  sims = cosine_similarity(emb_c, emb_s)[0]
65
  ranked = [s for s, _ in sorted(zip(sentences, sims), key=lambda x: x[1], reverse=True)]
66
  return ranked[:top_k]
@@ -119,34 +123,34 @@ def get_evidence_sentences(claim, k=3):
119
 
120
  # --- Classification Function ---
121
  def classify_text(claim, text_classifier):
122
-
123
- evidences = get_evidence_sentences(claim)
124
- print(evidences)
125
- if not evidences or "Error" in evidences[0]:
126
- return f"Prediction: Unknown\n\nTop Evidences:\n{evidences[0]}\n\nExplanation:\nUnable to retrieve reliable evidences."
127
-
128
- # --- Prepare model input ---
129
  evidence_text = " ".join(evidences)
130
- text = f"Claim: {claim}\nEvidence: {evidence_text}"
131
 
132
- # --- Run classifier ---
133
- text_classifier = get_text_classifier()
134
  result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
135
  top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
136
-
137
- # Map labels according to your model setup
138
- label_str = "REAL" if top_label in ["LABEL_0", "REAL", "SUPPORTED"] else "FAKE"
139
-
140
- explanation = (
141
- f"Based on semantically relevant and filtered evidences, "
142
- f"this claim is **{top_label}**."
143
- )
144
-
145
- return (
146
- f"Prediction: {top_label}\n\n"
147
- f"Top Evidences:\n" + "\n".join(f"- {e}" for e in evidences) +
148
- f"\n\nExplanation:\n{explanation}"
149
- )
 
 
 
 
 
 
150
 
151
 
152
  # -------------------
 
42
  # print("Image model already exists.")
43
 
44
  text_classifier = None
45
+
46
+ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
47
+ explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
48
+ text_model = "rajyalakshmijampani/fever_finetuned_deberta"
49
+
50
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
51
  TAVILY_KEY = os.getenv("TAVILY_API_KEY")
52
  GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
 
56
  global text_classifier
57
  if text_classifier is None:
58
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
59
+ tokenizer = AutoTokenizer.from_pretrained(text_model)
60
+ seq_clf = AutoModelForSequenceClassification.from_pretrained(text_model)
61
+ text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
62
  return text_classifier
63
 
64
  def _rank_sentences(claim, sentences, top_k=4):
65
  if not sentences: return []
66
+ emb_c = embed_model.encode([claim])
67
+ emb_s = embed_model.encode(sentences)
68
  sims = cosine_similarity(emb_c, emb_s)[0]
69
  ranked = [s for s, _ in sorted(zip(sentences, sims), key=lambda x: x[1], reverse=True)]
70
  return ranked[:top_k]
 
123
 
124
  # --- Classification Function ---
125
  def classify_text(claim, text_classifier):
126
+ text_classifier = get_text_classifier()
127
+ evidences = get_evidence_sentences(claim)
 
 
 
 
 
128
  evidence_text = " ".join(evidences)
 
129
 
130
+ # Step 1: FEVER classification
131
+ text = f"claim: {claim} evidence: {evidence_text}"
132
  result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
133
  top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
134
+ label_str = "REAL" if top_label == "LABEL_0" else "FAKE"
135
+
136
+ # Step 2: Mistral explanation generation
137
+ explain_pipe = pipeline("text-generation", model=explain_model, tokenizer=explain_model,
138
+ max_new_tokens=150, temperature=0.5, repetition_penalty=1.2)
139
+ prompt = f"""
140
+ You are a fact-checking assistant.
141
+ Claim: {claim}
142
+ Evidence:
143
+ {chr(10).join(f"- {e}" for e in evidences)}
144
+ The model predicts this claim is {label_str}.
145
+ Write a clear, human-readable explanation of why this classification makes sense, correcting the label if the evidence clearly contradicts it.
146
+ """
147
+
148
+ expl = explain_pipe(prompt)[0]["generated_text"].split("Evidence:")[-1].strip()
149
+
150
+ return f"Prediction: {label_str} \n\n \
151
+ Top Evidences:\n" + \
152
+ "\n".join(f"- {e}" for e in evidences) + \
153
+ f"\n\nExplanation:\n{expl}"
154
 
155
 
156
  # -------------------