Commit 8d0a810
Parent(s): b4e73b5

image classifier update

- app.py (+62, -26)
- requirements.txt (+0, -1)
app.py (CHANGED)
```diff
@@ -1,11 +1,11 @@
 import os
 import torch
+from torch import nn
 import json
 import requests
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
-from …  (rest of the line is cut off in the page render)
-from io import BytesIO
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, CLIPProcessor, CLIPModel
+from collections import OrderedDict
 import wikipedia
 import wikipediaapi
 import re
```
```diff
@@ -14,7 +14,24 @@ from sklearn.metrics.pairwise import cosine_similarity
 from tavily import TavilyClient
 from huggingface_hub import InferenceClient
 
+class CLIPImageClassifier(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+        self.classifier = nn.Sequential(
+            nn.Linear(self.clip.config.vision_config.hidden_size, 256),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, pixel_values):
+        feats = self.clip.vision_model(pixel_values=pixel_values).pooler_output
+        return self.classifier(feats)
+
 text_classifier = None
+image_classifier = None
 TAVILY_KEY = None
 GOOGLE_KEY = None
 HF_TOKEN = None
```
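The new `CLIPImageClassifier` reuses the pretrained CLIP ViT-B/32 vision tower and adds a small binary head on its pooled output (hidden size 768 → 256 → 1, ending in a sigmoid), so a forward pass yields one probability per image. A quick shape check, sketched below assuming the class from the hunk above is in scope and `torch`/`transformers` are installed (ViT-B/32 expects 3×224×224 inputs):

```python
import torch

model = CLIPImageClassifier()          # class defined in the hunk above
model.eval()

dummy = torch.randn(2, 3, 224, 224)    # two fake RGB images at ViT-B/32's 224x224
with torch.no_grad():
    probs = model(dummy)

print(probs.shape)                            # torch.Size([2, 1])
assert ((probs >= 0) & (probs <= 1)).all()    # sigmoid output stays in [0, 1]
```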
```diff
@@ -23,6 +40,7 @@ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
 explain_model = "meta-llama/Llama-3.1-8B-Instruct"
 text_model = "rajyalakshmijampani/fever_finetuned_deberta"
 wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
+image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 def get_text_classifier():
     global text_classifier
```
```diff
@@ -32,6 +50,29 @@ def get_text_classifier():
     text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
     return text_classifier
 
+def get_image_classifier():
+    global image_classifier
+    if image_classifier is None:
+        url = "https://huggingface.co/rajyalakshmijampani/finetuned_clip/resolve/main/best_clip_finetuned_classifier.pth"
+        path = "best_clip_finetuned_classifier.pth"
+
+        if not os.path.exists(path):
+            r = requests.get(url)
+            with open(path, "wb") as f:
+                f.write(r.content)
+
+        image_classifier = CLIPImageClassifier()
+        state = torch.load(path, map_location="cpu")
+        clean_state = OrderedDict(
+            (k[7:], v) if k.startswith("module.") else (k, v)
+            for k, v in state.items()
+        )
+        image_classifier.load_state_dict(clean_state, strict=False)
+        image_classifier.eval()
+        return image_classifier
+
+    return image_classifier
+
 def _rank_sentences(claim, sentences, top_k=4):
     if not sentences: return []
     emb_c = embed_model.encode([claim])
```
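`get_image_classifier` lazily downloads the checkpoint on first use, then strips the `module.` prefix that `torch.nn.DataParallel` prepends to every parameter name when a wrapped model is saved; without that cleanup none of the keys would match, and because the load uses `strict=False`, the mismatch would be silent rather than an error. (The trailing `return image_classifier` is redundant, since the `if` branch already returns.) A self-contained illustration of the renaming step, using a throwaway linear layer rather than the real checkpoint:

```python
from collections import OrderedDict

from torch import nn

net = nn.Linear(4, 2)
state = nn.DataParallel(net).state_dict()
print(list(state))          # ['module.weight', 'module.bias']

# Same renaming expression the loader uses: drop the 7-char "module." prefix.
clean_state = OrderedDict(
    (k[7:], v) if k.startswith("module.") else (k, v)
    for k, v in state.items()
)
print(list(clean_state))    # ['weight', 'bias']

nn.Linear(4, 2).load_state_dict(clean_state)   # keys now line up (strict=True)
```

Since the app already imports from `huggingface_hub`, `hf_hub_download` would arguably be a more robust way to fetch the `.pth` (cached, resumable) than a bare `requests.get`, though both work.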
```diff
@@ -115,7 +156,7 @@ def get_evidence_sentences(claim, k=3):
     evid = [e for e in evid if len(e.strip()) > 10]
     return (evid or ["Error: No relevant evidence found."])[:k]
 
-# --- Classification Function ---
+# ---Text Classification Function ---
 def classify_text(claim, hf_token, tavily_key, google_key):
 
     global HF_TOKEN, TAVILY_KEY, GOOGLE_KEY
```
```diff
@@ -184,24 +225,19 @@ def classify_text(claim, hf_token, tavily_key, google_key):
     return formatted_output.strip()
 
 
-#
-…  (removed lines 188-199 are blank in the page render and could not be recovered)
-    output = image_model(img_tensor)
-    preds = torch.softmax(output, dim=1)
-    label = torch.argmax(preds).item()
-    label_str = "REAL" if label == 1 else "FAKE"
-    return f"Prediction: {label_str}\n\nExplanation: The image model classifies this as {label_str.lower()} based on learned patterns."
+# ---- Image classification Function ----
+def classify_image(image):
+    global image_processor
+    classifier = get_image_classifier()
+    try:
+        inputs = image_processor(images=image.convert("RGB"), return_tensors="pt")["pixel_values"]
+        with torch.no_grad():
+            output = classifier(inputs)
+        p = output.item()
+        label = "Fake" if p > 0.5 else "Real"
+        return f"**Prediction:** {label}\n**Confidence score:** {p:.2f}"
+    except Exception as e:
+        return f"Error: {e}"
 
 # -------------------
 # UI Layout (Gradio)
```
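As the new code reads, the sigmoid output `p` is treated as the probability of the Fake class (`p > 0.5` maps to Fake); note the polarity differs from the removed code, where label 1 meant REAL. One subtlety: the message reports `p` itself as the "confidence score", so an image classified Real at `p = 0.03` displays 0.03 rather than 0.97. A hypothetical tweak (not in the commit) that reports confidence in the predicted label instead:

```python
# Hypothetical refinement: the sigmoid output p is P(fake), not the
# confidence of whichever label won, so flip it for the "Real" side.
def format_prediction(p: float) -> str:
    label = "Fake" if p > 0.5 else "Real"
    confidence = p if p > 0.5 else 1.0 - p
    return f"**Prediction:** {label}\n**Confidence score:** {confidence:.2f}"

print(format_prediction(0.03))   # Prediction: Real, Confidence score: 0.97
print(format_prediction(0.88))   # Prediction: Fake, Confidence score: 0.88
```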
```diff
@@ -219,9 +255,9 @@ with gr.Blocks() as demo:
 
     with gr.Column(scale=1): # Right half — user token inputs
         gr.Markdown("## Enter your API keys")
-        hf_token = gr.Textbox(label="Hugging Face Token", type="password")
-        tavily_key = gr.Textbox(label="Tavily API Key", type="password")
-        google_key = gr.Textbox(label="Google Fact Check API Key", type="password")
+        hf_token = gr.Textbox(label="Hugging Face Token 🔴", type="password")
+        tavily_key = gr.Textbox(label="Tavily API Key 🔴", type="password")
+        google_key = gr.Textbox(label="Google Fact Check API Key 🔴", type="password")
 
     # Enable button when all fields filled
     def enable_button(hf, tavily, google):
```
```diff
@@ -240,7 +276,7 @@ with gr.Blocks() as demo:
 
     with gr.Tab("Image Detector"):
         img_input = gr.Image(type="pil", label="Upload Image")
-        img_output = gr.…  (rest of the line is cut off in the page render)
+        img_output = gr.Markdown(label="Model Output", value="Results will appear here...")
         img_button = gr.Button("Classify Image")
         img_button.click(classify_image, inputs=img_input, outputs=img_output)
```
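The old output component is cut off in the page render; the new `gr.Markdown` output is what lets the `**bold**` markup returned by `classify_image` render as formatting instead of literal asterisks. And with `type="pil"`, Gradio hands the handler a `PIL.Image` directly, which is also why the old `from io import BytesIO` import could go. A minimal, self-contained sketch of the same Image → Markdown wiring, with a stand-in handler in place of the real classifier:

```python
# Minimal sketch of the Image -> Markdown wiring used in the tab. The handler
# here is a stand-in; in the app, classify_image runs the CLIP classifier.
import gradio as gr

def fake_classify(image):
    if image is None:
        return "Error: no image uploaded."
    w, h = image.size                      # type="pil" hands us a PIL.Image
    return f"**Prediction:** Real\n**Confidence score:** 0.99 ({w}x{h})"

with gr.Blocks() as demo:
    img_input = gr.Image(type="pil", label="Upload Image")
    img_output = gr.Markdown(value="Results will appear here...")
    img_button = gr.Button("Classify Image")
    img_button.click(fake_classify, inputs=img_input, outputs=img_output)

demo.launch()
```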
requirements.txt (CHANGED)
```diff
@@ -2,7 +2,6 @@ gradio
 torch
 transformers
 requests
-Pillow
 wikipedia-api
 wikipedia
 sentence-transformers
```
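Dropping `Pillow` as a direct requirement is safe in practice only because `gradio` itself depends on Pillow, so `PIL` stays importable and `gr.Image(type="pil")` keeps working; if Gradio ever left the dependency list, Pillow would need to be pinned again. A quick check of the transitive install:

```python
# Verify that Pillow is still installed transitively (gradio depends on it)
# even though requirements.txt no longer lists it directly.
import PIL
import gradio

print("Pillow:", PIL.__version__)
print("Gradio:", gradio.__version__)
```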