rajyalakshmijampani committed on
Commit
26212b8
·
1 Parent(s): 6ff8947

hf interface for mistral

Browse files
Files changed (2) hide show
  1. app.py +20 -51
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  import requests
4
  import gradio as gr
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, pipeline
6
  from PIL import Image
7
  from io import BytesIO
8
  import wikipedia
@@ -11,46 +11,19 @@ import re
11
  from sentence_transformers import SentenceTransformer
12
  from sklearn.metrics.pairwise import cosine_similarity
13
  from tavily import TavilyClient
14
-
15
- # # Download from Google Drive
16
- # def download_from_drive(drive_url, dest_path):
17
- # gdown.download(drive_url, dest_path, quiet=False)
18
- # return
19
-
20
- # # Download models
21
- # TEXT_MODEL_ZIP_URL = "https://drive.google.com/uc?export=download&id=1Sf2DoVaYBqBcdvonf6GJpo_bLWATSgeq"
22
- # IMAGE_MODEL_URL = "https://drive.google.com/uc?export=download&id=19xRLjNtGWty9loc0_6LPjIYOl-EIf2bm"
23
-
24
- # os.makedirs("models", exist_ok=True)
25
-
26
- # # Text model
27
- # if not os.path.exists("models/text_model"):
28
- # print("Downloading and extracting text model...")
29
- # zip_path = "models/text_model.zip"
30
- # download_from_drive(TEXT_MODEL_ZIP_URL, zip_path)
31
- # with zipfile.ZipFile(zip_path, 'r') as zip_ref:
32
- # zip_ref.extractall("models/text_model")
33
- # else:
34
- # print("Text model already exists.")
35
-
36
- # # Image model
37
- # if not os.path.exists("models/image_model.pth"):
38
- # print("Downloading image model...")
39
- # pth_path = "models/image_model.pth"
40
- # download_from_drive(IMAGE_MODEL_URL, pth_path)
41
- # else:
42
- # print("Image model already exists.")
43
 
44
  text_classifier = None
45
- text_explainer = None
 
 
46
 
47
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
48
  explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
49
  text_model = "rajyalakshmijampani/fever_finetuned_deberta"
50
 
 
51
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
52
- TAVILY_KEY = os.getenv("TAVILY_API_KEY")
53
- GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
54
  tavily = TavilyClient(api_key=TAVILY_KEY)
55
 
56
  def get_text_classifier():
@@ -61,14 +34,6 @@ def get_text_classifier():
61
  text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
62
  return text_classifier
63
 
64
- def get_text_explainer():
65
- global text_explainer
66
- if text_explainer is None:
67
- tokenizer = AutoTokenizer.from_pretrained(explain_model)
68
- clm = AutoModelForCausalLM.from_pretrained(explain_model)
69
- text_explainer = pipeline("text-generation", model=clm, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, repetition_penalty=1.2)
70
- return text_explainer
71
-
72
  def _rank_sentences(claim, sentences, top_k=4):
73
  if not sentences: return []
74
  emb_c = embed_model.encode([claim])
@@ -130,30 +95,34 @@ def get_evidence_sentences(claim, k=3):
130
  return (evid or ["Error: No relevant evidence found."])[:k]
131
 
132
  # --- Classification Function ---
133
- def classify_text(claim, text_classifier):
134
- text_classifier = get_text_classifier()
135
- text_explainer = get_text_explainer()
136
  evidences = get_evidence_sentences(claim)
137
  evidence_text = " ".join(evidences)
138
 
139
  # Step 1: FEVER classification
140
  text = f"claim: {claim} evidence: {evidence_text}"
141
- result = text_classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
142
  top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
143
  label_str = "REAL" if top_label == "LABEL_0" else "FAKE"
144
 
145
  # Step 2: Mistral explanation generation
146
  prompt = f"""
147
- You are a fact-checking assistant.
148
  Claim: {claim}
149
  Evidence: {chr(10).join(f"- {e}" for e in evidences)}
150
  The model predicts this claim is {label_str}.
151
- Write a clear, human-readable explanation of why this classification makes sense, correcting the label if the evidence clearly contradicts it.
 
152
  """
153
-
154
- expl = text_explainer(prompt)[0]["generated_text"]
155
-
156
- return f"Prediction: {label_str} + \n\nExplanation:\n{expl}"
 
 
 
 
157
 
158
 
159
  # -------------------
 
2
  import torch
3
  import requests
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
6
  from PIL import Image
7
  from io import BytesIO
8
  import wikipedia
 
11
  from sentence_transformers import SentenceTransformer
12
  from sklearn.metrics.pairwise import cosine_similarity
13
  from tavily import TavilyClient
14
+ from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  text_classifier = None
17
+ TAVILY_KEY = os.getenv("TAVILY_API_KEY")
18
+ GOOGLE_KEY = os.getenv("GOOGLE_FC_KEY")
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
 
21
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
22
  explain_model = "mistralai/Mistral-7B-Instruct-v0.2"
23
  text_model = "rajyalakshmijampani/fever_finetuned_deberta"
24
 
25
+ inf_client = InferenceClient(token=HF_TOKEN)
26
  wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
 
 
27
  tavily = TavilyClient(api_key=TAVILY_KEY)
28
 
29
  def get_text_classifier():
 
34
  text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
35
  return text_classifier
36
 
 
 
 
 
 
 
 
 
37
  def _rank_sentences(claim, sentences, top_k=4):
38
  if not sentences: return []
39
  emb_c = embed_model.encode([claim])
 
95
  return (evid or ["Error: No relevant evidence found."])[:k]
96
 
97
  # --- Classification Function ---
98
+ def classify_text(claim):
99
+ classifier = get_text_classifier()
 
100
  evidences = get_evidence_sentences(claim)
101
  evidence_text = " ".join(evidences)
102
 
103
  # Step 1: FEVER classification
104
  text = f"claim: {claim} evidence: {evidence_text}"
105
+ result = classifier(text, truncation=True, max_length=512, return_all_scores=True)[0]
106
  top_label = sorted(result, key=lambda x: x["score"], reverse=True)[0]["label"]
107
  label_str = "REAL" if top_label == "LABEL_0" else "FAKE"
108
 
109
  # Step 2: Mistral explanation generation
110
  prompt = f"""
111
+ You are a reliable fact-checking assistant.
112
  Claim: {claim}
113
  Evidence: {chr(10).join(f"- {e}" for e in evidences)}
114
  The model predicts this claim is {label_str}.
115
+ Write a short, clear explanation of why this classification makes sense.
116
+ If the evidence clearly contradicts the label, correct the label in your explanation.
117
  """
118
+ messages = [
119
+ {"role": "system", "content": "You are a reliable fact-checking assistant."},
120
+ {"role": "user", "content": prompt},
121
+ ]
122
+ completion = inf_client.chat_completion( model=explain_model, messages=messages, max_tokens=256, temperature=0.3)
123
+ explanation = completion.choices[0].message.content.strip()
124
+
125
+ return f"Prediction: {label_str} + \n\nExplanation:\n{explanation}"
126
 
127
 
128
  # -------------------
requirements.txt CHANGED
@@ -7,4 +7,5 @@ wikipedia-api
7
  wikipedia
8
  sentence-transformers
9
  scikit-learn
10
- tavily-python
 
 
7
  wikipedia
8
  sentence-transformers
9
  scikit-learn
10
+ tavily-python
11
+ huggingface-hub