rajyalakshmijampani commited on
Commit
b758e97
·
1 Parent(s): bd95f0c

gradio fixes

Browse files
Files changed (1) hide show
  1. app.py +108 -60
app.py CHANGED
@@ -1,72 +1,120 @@
 
 
 
 
1
  import gradio as gr
2
- import torch, zipfile, os, tempfile, requests
3
  from PIL import Image
 
4
 
5
- # --- Download helper ---
6
- def download_file_from_google_drive(url, dest):
7
- if os.path.exists(dest): # cached
8
- return dest
9
- print(f"⬇️ Downloading from {url}")
10
- r = requests.get(url, allow_redirects=True)
11
- with open(dest, "wb") as f:
12
- f.write(r.content)
13
- return dest
14
 
15
- # --- Text model loader ---
16
- def load_text_model():
17
- zip_path = download_file_from_google_drive(
18
- "https://drive.google.com/uc?export=download&id=1Sf2DoVaYBqBcdvonf6GJpo_bLWATSgeq",
19
- "text_model.zip")
20
- with zipfile.ZipFile(zip_path, 'r') as z:
21
- z.extractall("text_model")
22
- # example loading — replace with your own
23
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
24
- tokenizer = AutoTokenizer.from_pretrained("text_model")
25
- model = AutoModelForSequenceClassification.from_pretrained("text_model")
26
- return tokenizer, model
27
 
28
- # --- Image model loader ---
29
- def load_image_model():
30
- path = download_file_from_google_drive(
31
- "https://drive.google.com/uc?export=download&id=19xRLjNtGWty9loc0_6LPjIYOl-EIf2bm",
32
- "image_model.pth")
33
- model = torch.load(path, map_location="cpu")
34
- model.eval()
35
- return model
36
 
37
- # Lazy caching
38
- tokenizer, text_model, image_model = None, None, None
39
 
40
- # --- Text classification ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def classify_text(claim):
42
- global tokenizer, text_model
43
- if tokenizer is None:
44
- tokenizer, text_model = load_text_model()
45
- # (Fake retrieval for now)
46
- evidences = ["Evidence 1", "Evidence 2", "Evidence 3"]
47
- inp = claim + " " + " ".join(evidences)
48
- inputs = tokenizer(inp, return_tensors="pt", truncation=True)
49
- out = text_model(**inputs).logits
50
- label = out.argmax(-1).item()
51
- label = "REAL" if label == 1 else "FAKE"
52
- return f"{label}\n\nTop evidences:\n" + "\n".join(evidences)
53
 
54
- # --- Image classification ---
 
 
55
  def classify_image(img):
56
- global image_model
57
- if image_model is None:
58
- image_model = load_image_model()
59
- # preprocess and predict (dummy)
60
- # You’ll replace this with your transforms + inference
61
- return "REAL" if torch.rand(1).item() > 0.5 else "FAKE"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # --- Gradio UI ---
64
- demo = gr.Interface(
65
- fn=[classify_text, classify_image],
66
- inputs=[gr.Textbox(label="Enter claim text"), gr.Image(type="pil", label="Upload image")],
67
- outputs=[gr.Textbox(label="Text Result"), gr.Textbox(label="Image Result")],
68
- title="Text & Image Real/Fake Classifier"
69
- )
70
 
71
- if __name__ == "__main__":
72
- demo.launch()
 
1
+ import os
2
+ import torch
3
+ import zipfile
4
+ import requests
5
  import gradio as gr
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  from PIL import Image
8
+ from io import BytesIO
9
 
10
+ # -------------------
11
+ # Utility: Download from Google Drive
12
+ # -------------------
13
+ def download_from_drive(drive_url, dest_path):
14
+ if os.path.exists(dest_path):
15
+ print(f"✅ Found {dest_path}, skipping download.")
16
+ return dest_path
 
 
17
 
18
+ print(f"⬇️ Downloading {drive_url} ...")
19
+ file_id = drive_url.split("id=")[-1].split("&")[0]
20
+ download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
21
+ response = requests.get(download_url)
22
+ with open(dest_path, "wb") as f:
23
+ f.write(response.content)
24
+ print(f"✅ Saved to {dest_path}")
25
+ return dest_path
 
 
 
 
26
 
27
+ # -------------------
28
+ # Download models (modify these links!)
29
+ # -------------------
30
+ TEXT_MODEL_ZIP_URL = "https://drive.google.com/uc?export=download&id=1WUB7JzrhWXFBFFsKn6PAKh_4F3410NPZ"
31
+ IMAGE_MODEL_URL = "https://drive.google.com/uc?export=download&id=1WUB7JzrhWXFBFFsKn6PAKh_4F3410NPZ"
 
 
 
32
 
33
+ os.makedirs("models", exist_ok=True)
 
34
 
35
+ # Text model
36
+ zip_path = download_from_drive(TEXT_MODEL_ZIP_URL, "models/text_model.zip")
37
+ if not os.path.exists("models/text_model"):
38
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
39
+ zip_ref.extractall("models/text_model")
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained("models/text_model")
42
+ text_model = AutoModelForSequenceClassification.from_pretrained("models/text_model")
43
+
44
+ # Image model
45
+ image_model_path = download_from_drive(IMAGE_MODEL_URL, "models/image_model.pth")
46
+ image_model = torch.load(image_model_path, map_location=torch.device("cpu"))
47
+ image_model.eval()
48
+
49
+ # -------------------
50
+ # Tavily evidence retrieval (mocked if no key)
51
+ # -------------------
52
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
53
+
54
+ def get_top3_evidence(claim):
55
+ if not TAVILY_API_KEY:
56
+ return ["Tavily API key not set. Using dummy evidences."]
57
+ try:
58
+ response = requests.post(
59
+ "https://api.tavily.com/search",
60
+ headers={"Authorization": f"Bearer {TAVILY_API_KEY}"},
61
+ json={"query": claim, "num_results": 3},
62
+ )
63
+ data = response.json()
64
+ results = [r["content"] for r in data.get("results", [])][:3]
65
+ return results
66
+ except Exception as e:
67
+ return [f"Error getting evidence: {str(e)}"]
68
+
69
+ # -------------------
70
+ # Text classification
71
+ # -------------------
72
  def classify_text(claim):
73
+ evidences = get_top3_evidence(claim)
74
+ full_input = claim + " " + " ".join(evidences)
75
+ inputs = tokenizer(full_input, return_tensors="pt", truncation=True, padding=True)
76
+ outputs = text_model(**inputs)
77
+ preds = torch.softmax(outputs.logits, dim=1)
78
+ label = torch.argmax(preds).item()
79
+ label_str = "REAL" if label == 1 else "FAKE"
80
+ explanation = f"Based on the retrieved evidences and model prediction, this claim is **{label_str}**."
81
+ return f"Prediction: {label_str}\n\nTop Evidences:\n" + "\n".join(evidences) + f"\n\nExplanation:\n{explanation}"
 
 
82
 
83
+ # -------------------
84
+ # Image classification
85
+ # -------------------
86
  def classify_image(img):
87
+ if img is None:
88
+ return "Please upload an image."
89
+ transform = torch.nn.Sequential(
90
+ torch.nn.Identity() # 👈 replace with actual transforms if needed
91
+ )
92
+ img_tensor = torch.tensor(
93
+ [list(img.resize((224, 224)).getdata())], dtype=torch.float32
94
+ ).view(1, 224, 224, 3).permute(0, 3, 1, 2) / 255.0
95
+ with torch.no_grad():
96
+ output = image_model(img_tensor)
97
+ preds = torch.softmax(output, dim=1)
98
+ label = torch.argmax(preds).item()
99
+ label_str = "REAL" if label == 1 else "FAKE"
100
+ return f"Prediction: {label_str}\n\nExplanation: The image model classifies this as {label_str.lower()} based on learned patterns."
101
+
102
+ # -------------------
103
+ # UI Layout (Gradio)
104
+ # -------------------
105
+ with gr.Blocks() as demo:
106
+ gr.Markdown("# 🧠 Multimodal Misinformation Detector")
107
+
108
+ with gr.Tab("Text Detector"):
109
+ claim = gr.Textbox(label="Enter Claim")
110
+ text_output = gr.Textbox(label="Model Output", lines=8)
111
+ text_button = gr.Button("Classify Claim")
112
+ text_button.click(classify_text, inputs=claim, outputs=text_output)
113
 
114
+ with gr.Tab("Image Detector"):
115
+ img_input = gr.Image(type="pil", label="Upload Image")
116
+ img_output = gr.Textbox(label="Model Output", lines=6)
117
+ img_button = gr.Button("Classify Image")
118
+ img_button.click(classify_image, inputs=img_input, outputs=img_output)
 
 
119
 
120
+ demo.launch()