rajyalakshmijampani committed
Commit 8d0a810 · 1 Parent(s): b4e73b5

image classifier update

Files changed (2)
  1. app.py +62 -26
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,11 +1,11 @@
 import os
 import torch
+from torch import nn
 import json
 import requests
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
-from PIL import Image
-from io import BytesIO
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, CLIPProcessor, CLIPModel
+from collections import OrderedDict
 import wikipedia
 import wikipediaapi
 import re
@@ -14,7 +14,24 @@ from sklearn.metrics.pairwise import cosine_similarity
 from tavily import TavilyClient
 from huggingface_hub import InferenceClient
 
+class CLIPImageClassifier(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+        self.classifier = nn.Sequential(
+            nn.Linear(self.clip.config.vision_config.hidden_size, 256),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, pixel_values):
+        feats = self.clip.vision_model(pixel_values=pixel_values).pooler_output
+        return self.classifier(feats)
+
 text_classifier = None
+image_classifier = None
 TAVILY_KEY = None
 GOOGLE_KEY = None
 HF_TOKEN = None
@@ -23,6 +40,7 @@ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
 explain_model = "meta-llama/Llama-3.1-8B-Instruct"
 text_model = "rajyalakshmijampani/fever_finetuned_deberta"
 wiki = wikipediaapi.Wikipedia(language='en', user_agent='fact-checker/1.0')
+image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 def get_text_classifier():
     global text_classifier
@@ -32,6 +50,29 @@ def get_text_classifier():
     text_classifier = pipeline("text-classification", model=seq_clf, tokenizer=tokenizer)
     return text_classifier
 
+def get_image_classifier():
+    global image_classifier
+    if image_classifier is None:
+        url = "https://huggingface.co/rajyalakshmijampani/finetuned_clip/resolve/main/best_clip_finetuned_classifier.pth"
+        path = "best_clip_finetuned_classifier.pth"
+
+        if not os.path.exists(path):
+            r = requests.get(url)
+            with open(path, "wb") as f:
+                f.write(r.content)
+
+        image_classifier = CLIPImageClassifier()
+        state = torch.load(path, map_location="cpu")
+        clean_state = OrderedDict(
+            (k[7:], v) if k.startswith("module.") else (k, v)
+            for k, v in state.items()
+        )
+        image_classifier.load_state_dict(clean_state, strict=False)
+        image_classifier.eval()
+        return image_classifier
+
+    return image_classifier
+
 def _rank_sentences(claim, sentences, top_k=4):
     if not sentences: return []
     emb_c = embed_model.encode([claim])
@@ -115,7 +156,7 @@ def get_evidence_sentences(claim, k=3):
     evid = [e for e in evid if len(e.strip()) > 10]
     return (evid or ["Error: No relevant evidence found."])[:k]
 
-# --- Classification Function ---
+# --- Text Classification Function ---
 def classify_text(claim, hf_token, tavily_key, google_key):
 
     global HF_TOKEN, TAVILY_KEY, GOOGLE_KEY
@@ -184,24 +225,19 @@ def classify_text(claim, hf_token, tavily_key, google_key):
     return formatted_output.strip()
 
 
-# -------------------
-# Image classification
-# -------------------
-def classify_image(img):
-    if img is None:
-        return "Please upload an image."
-    transform = torch.nn.Sequential(
-        torch.nn.Identity() # 👈 replace with actual transforms if needed
-    )
-    img_tensor = torch.tensor(
-        [list(img.resize((224, 224)).getdata())], dtype=torch.float32
-    ).view(1, 224, 224, 3).permute(0, 3, 1, 2) / 255.0
-    with torch.no_grad():
-        output = image_model(img_tensor)
-    preds = torch.softmax(output, dim=1)
-    label = torch.argmax(preds).item()
-    label_str = "REAL" if label == 1 else "FAKE"
-    return f"Prediction: {label_str}\n\nExplanation: The image model classifies this as {label_str.lower()} based on learned patterns."
+# ---- Image Classification Function ----
+def classify_image(image):
+    global image_processor
+    classifier = get_image_classifier()
+    try:
+        inputs = image_processor(images=image.convert("RGB"), return_tensors="pt")["pixel_values"]
+        with torch.no_grad():
+            output = classifier(inputs)
+        p = output.item()
+        label = "Fake" if p > 0.5 else "Real"
+        return f"**Prediction:** {label}\n**Confidence score:** {p:.2f}"
+    except Exception as e:
+        return f"Error: {e}"
 
 # -------------------
 # UI Layout (Gradio)
@@ -219,9 +255,9 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=1): # Right half — user token inputs
             gr.Markdown("## Enter your API keys")
-            hf_token = gr.Textbox(label="Hugging Face Token", type="password", value = "Required")
-            tavily_key = gr.Textbox(label="Tavily API Key", type="password", value = "Required")
-            google_key = gr.Textbox(label="Google Fact Check API Key", type="password", value = "Required")
+            hf_token = gr.Textbox(label="Hugging Face Token 🔴", type="password")
+            tavily_key = gr.Textbox(label="Tavily API Key 🔴", type="password")
+            google_key = gr.Textbox(label="Google Fact Check API Key 🔴", type="password")
 
     # Enable button when all fields filled
     def enable_button(hf, tavily, google):
@@ -240,7 +276,7 @@ with gr.Blocks() as demo:
 
     with gr.Tab("Image Detector"):
        img_input = gr.Image(type="pil", label="Upload Image")
-        img_output = gr.Textbox(label="Model Output", lines=8)
+        img_output = gr.Markdown(label="Model Output", value="Results will appear here...")
        img_button = gr.Button("Classify Image")
        img_button.click(classify_image, inputs=img_input, outputs=img_output)
 
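A note on the OrderedDict cleanup inside get_image_classifier(): checkpoints saved from an nn.DataParallel-wrapped model carry a "module." prefix on every parameter key, which a plain module cannot load. A minimal sketch of the same pattern, with a toy nn.Linear standing in for the real classifier:

from torch import nn
from collections import OrderedDict

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)     # registers net as the submodule "module"
state = wrapped.state_dict()       # keys: "module.weight", "module.bias"

# Strip the 7-character "module." prefix so the unwrapped module accepts the weights
clean = OrderedDict(
    (k[7:], v) if k.startswith("module.") else (k, v)
    for k, v in state.items()
)
net.load_state_dict(clean)         # loads without missing/unexpected keys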
 
 
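For reference, a minimal sketch of the new inference path outside the Gradio UI. It assumes the checkpoint has already been downloaded (get_image_classifier() above does that on first call), that CLIPImageClassifier is imported or copied from app.py, and it uses a hypothetical test file photo.jpg:

import torch
from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPImageClassifier()      # class defined in app.py above
state = torch.load("best_clip_finetuned_classifier.pth", map_location="cpu")
model.load_state_dict(state, strict=False)   # strip "module." prefixes first if present
model.eval()

pixels = processor(images=Image.open("photo.jpg").convert("RGB"),
                   return_tensors="pt")["pixel_values"]
with torch.no_grad():
    p = model(pixels).item()       # sigmoid output in [0, 1]
print("Fake" if p > 0.5 else "Real", f"(score: {p:.2f})")   # same 0.5 threshold as classify_image()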
 
requirements.txt CHANGED
@@ -2,7 +2,6 @@ gradio
 torch
 transformers
 requests
-Pillow
 wikipedia-api
 wikipedia
 sentence-transformers
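Dropping Pillow from requirements.txt should be safe here: app.py no longer imports PIL directly, and gradio itself depends on Pillow, which keeps gr.Image(type="pil") working.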