gyigit committed
Commit 54e8a79 · 1 Parent(s): b61672d
Files changed (47)
  1. app.py +446 -0
  2. ckpts/model.pt +3 -0
  3. src/__init__.py +0 -0
  4. src/__pycache__/__init__.cpython-311.pyc +0 -0
  5. src/data_loader/__init__.py +0 -0
  6. src/data_loader/__pycache__/__init__.cpython-311.pyc +0 -0
  7. src/data_loader/__pycache__/download_data.cpython-311.pyc +0 -0
  8. src/data_loader/__pycache__/download_images.cpython-311.pyc +0 -0
  9. src/data_loader/download_data.py +76 -0
  10. src/data_loader/download_data_mocheg.py +71 -0
  11. src/data_loader/download_images.py +168 -0
  12. src/data_loader/preprocess_embeddings.py +129 -0
  13. src/demo/__init__.py +0 -0
  14. src/demo/__pycache__/__init__.cpython-311.pyc +0 -0
  15. src/demo/__pycache__/app.cpython-311.pyc +0 -0
  16. src/demo/app.py +446 -0
  17. src/evidence/__init__.py +0 -0
  18. src/evidence/__pycache__/__init__.cpython-311.pyc +0 -0
  19. src/evidence/__pycache__/corpus_utils.cpython-311.pyc +0 -0
  20. src/evidence/__pycache__/im2im_retrieval.cpython-311.pyc +0 -0
  21. src/evidence/__pycache__/text2text_retrieval.cpython-311.pyc +0 -0
  22. src/evidence/corpus_utils.py +100 -0
  23. src/evidence/im2im_retrieval.py +169 -0
  24. src/evidence/text2text_retrieval.py +203 -0
  25. src/experimental/__init__.py +0 -0
  26. src/experimental/dataset_search.ipynb +0 -0
  27. src/experimental/dataset_stats.ipynb +0 -0
  28. src/experimental/image_captioning.ipynb +96 -0
  29. src/model/__init__.py +0 -0
  30. src/model/__pycache__/__init__.cpython-311.pyc +0 -0
  31. src/model/__pycache__/layers.cpython-311.pyc +0 -0
  32. src/model/__pycache__/model.cpython-311.pyc +0 -0
  33. src/model/dataset.py +164 -0
  34. src/model/layers.py +58 -0
  35. src/model/model.py +432 -0
  36. src/preprocess/__init__.py +0 -0
  37. src/preprocess/__pycache__/__init__.cpython-311.pyc +0 -0
  38. src/preprocess/__pycache__/caption.cpython-311.pyc +0 -0
  39. src/preprocess/__pycache__/preprocess.cpython-311.pyc +0 -0
  40. src/preprocess/caption.py +129 -0
  41. src/preprocess/preprocess.py +82 -0
  42. src/utils/__init__.py +0 -0
  43. src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  44. src/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  45. src/utils/__pycache__/path_utils.cpython-311.pyc +0 -0
  46. src/utils/data_utils.py +73 -0
  47. src/utils/path_utils.py +6 -0
app.py ADDED
@@ -0,0 +1,446 @@
import streamlit as st
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import pandas as pd
import os

from evaluate import MisinformationPredictor
from src.evidence.im2im_retrieval import ImageCorpus
from src.evidence.text2text_retrieval import SemanticSimilarity
from src.utils.path_utils import get_project_root
from typing import List, Optional, Tuple
from dataclasses import dataclass

# Initialize BLIP model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)

PROJECT_ROOT = get_project_root()


@dataclass
class Evidence:
    evidence_id: str
    dataset: str
    text: Optional[str]
    image: Optional[Image.Image]
    caption: Optional[str]
    image_path: Optional[str]
    classification_result_all: Optional[Tuple[str, str, str, str]] = None
    classification_result_final: Optional[str] = None


CLASSIFICATION_CATEGORIES = ["support", "refute", "not_enough_information"]


def generate_caption(image: Image.Image) -> str:
    """Generates a caption for a given image."""
    try:
        with st.spinner("Generating caption..."):
            inputs = processor(image, return_tensors="pt")
            output = model.generate(**inputs)
            return processor.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return ""


def enrich_text_with_caption(text: str, image_caption: str) -> str:
    """Appends the image caption to the given text."""
    if image_caption:
        return f"{text}. {image_caption}"
    return text


@st.cache_data
def get_train_df():
    data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
    train_csv_path = os.path.join(data_dir, "train_enriched.csv")
    return pd.read_csv(train_csv_path)


@st.cache_data
def get_test_df():
    data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
    test_csv_path = os.path.join(data_dir, "test_enriched.csv")
    return pd.read_csv(test_csv_path)


@st.cache_data
def get_semantic_similarity(
    train_embeddings_file: str,
    test_embeddings_file: str,
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
):
    return SemanticSimilarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=train_df,
        test_df=test_df,
    )


def retrieve_evidences_by_text(
    query: str,
    top_k: int = 5,
) -> List[Evidence]:
    """
    Retrieves evidence rows from preloaded embeddings and CSV data using semantic similarity.

    Args:
        query (str): The query text to perform the search.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    train_embeddings_file = os.path.join(PROJECT_ROOT, "train_embeddings.h5")
    test_embeddings_file = os.path.join(PROJECT_ROOT, "test_embeddings.h5")
    similarity = get_semantic_similarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=get_train_df(),
        test_df=get_test_df(),
    )
    evidences = []
    try:
        # Perform semantic search across both train and test datasets
        results = similarity.search(query=query, top_k=top_k)

        # Retrieve evidence rows based on the search results
        for evidence_id, score in results:
            # Determine whether the ID belongs to train or test set
            if evidence_id.startswith("train_"):
                df = similarity.train_csv
            elif evidence_id.startswith("test_"):
                df = similarity.test_csv
            else:
                continue  # Skip invalid IDs

            # Extract the row by ID
            row = df[df["id"] == int(evidence_id.split("_")[1])].iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            # Load the image if a valid path is provided
            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            evidence_id_number = evidence_id.split("_")[1]
            evidence_dataset = evidence_id.split("_")[0]

            # Create an Evidence object
            evidences.append(
                Evidence(
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    evidence_id=evidence_id_number,
                    dataset=evidence_dataset,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing semantic search: {e}")

    return evidences


@st.cache_data
def get_image_corpus(image_features):
    return ImageCorpus(image_features)


def retrieve_evidences_by_image(
    image_path: str,
    top_k: int = 5,
) -> List[Evidence]:
    """
    Retrieves evidence rows from the precomputed image-feature corpus and CSV data using image similarity.

    Args:
        image_path (str): Path to (or file-like object of) the query image.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    image_features = os.path.join(PROJECT_ROOT, "evidence_features.pkl")
    image_corpus = get_image_corpus(image_features)
    evidences = []
    try:
        # Perform image similarity search over the evidence corpus
        results = image_corpus.retrieve_similar_images(image_path, top_k=top_k)

        # Retrieve evidence rows based on the search results
        for evidence_path, score in results:
            evidence_id = evidence_path.split("/")[-1]
            evidence_id_number = evidence_id.split("_")[0]
            # Determine whether the ID belongs to train or test set
            if "train" in evidence_path:
                df = get_train_df()
            elif "test" in evidence_path:
                df = get_test_df()
            else:
                continue  # Skip invalid IDs

            # Extract the row by ID
            row = df[df["id"] == int(evidence_id_number)].iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            # Load the image if a valid path is provided
            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            # Create an Evidence object
            evidences.append(
                Evidence(
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    dataset=evidence_path.split("/")[-2],
                    evidence_id=evidence_id_number,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing image similarity search: {e}")

    return evidences


@st.cache_resource
def get_predictor():
    return MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")


def classify_evidence(
    claim_text: str, claim_image_path: str, evidence_text: str, evidence_image_path: str
) -> Tuple[str, str, str, str]:
    """Runs the misinformation predictor on each claim/evidence modality pair."""
    predictor = get_predictor()
    predictions = predictor.evaluate(
        claim_text, claim_image_path, evidence_text, evidence_image_path
    )
    if predictions:
        return (
            predictions.get("text_text", "not_enough_information"),
            predictions.get("text_image", "not_enough_information"),
            predictions.get("image_text", "not_enough_information"),
            predictions.get("image_image", "not_enough_information"),
        )
    else:
        return (
            "not_enough_information",
            "not_enough_information",
            "not_enough_information",
            "not_enough_information",
        )


def display_evidence_tab(evidences: List[Evidence], tab_label: str):
    """Displays the retrieved evidence in a tabbed layout."""
    with st.container():
        for index, evidence in enumerate(evidences):
            with st.container():
                st.subheader(f"Evidence {index + 1}")
                st.write(f"Evidence Dataset: {evidence.dataset}")
                st.write(f"Evidence ID: {evidence.evidence_id}")
                if evidence.image:
                    st.image(
                        evidence.image,
                        caption="Evidence Image",
                        use_container_width=True,
                    )
                st.text_area(
                    "Evidence Caption",
                    value=evidence.caption or "No caption available.",
                    height=100,
                    key=f"caption_{tab_label}_{index}",
                    disabled=True,
                )
                st.text_area(
                    "Evidence Text",
                    value=evidence.text or "No text available.",
                    height=100,
                    key=f"text_{tab_label}_{index}",
                    disabled=True,
                )
                if evidence.classification_result_all:
                    st.write("**Classification:**")
                    st.write(f"**text|text:** {evidence.classification_result_all[0]}")
                    st.write(f"**text|image:** {evidence.classification_result_all[1]}")
                    st.write(f"**image|text:** {evidence.classification_result_all[2]}")
                    st.write(
                        f"**image|image:** {evidence.classification_result_all[3]}"
                    )
                    st.write(
                        f"**Final classification result:** {evidence.classification_result_final}"
                    )


def get_final_classification(results: Tuple[str, str, str, str]) -> str:
    text_text = results[0]
    text_image = results[1]
    image_text = results[2]
    image_image = results[3]

    # Helper function to determine the final classification based on two inputs
    def resolve_classification(val1: str, val2: str) -> str:
        if val1 == val2 and val1 in {"support", "refute"}:
            return val1
        if (val1 in {"support", "refute"} and val2 == "not_enough_information") or (
            val2 in {"support", "refute"} and val1 == "not_enough_information"
        ):
            return val1 if val1 != "not_enough_information" else val2
        return "not_enough_information"

    # Step 1: Check text_text and image_image
    final_result = resolve_classification(text_text, image_image)
    if final_result != "not_enough_information":
        return final_result

    # Step 2: Check text_image and image_text
    final_result = resolve_classification(text_image, image_text)
    if final_result != "not_enough_information":
        return final_result

    # Step 3: If still undetermined, return "not_enough_information"
    return "not_enough_information"


def main():
    st.title("Multimodal Evidence-Based Misinformation Classification")
    st.write("Upload claims that have image and/or text content to verify.")

    # File uploader for images
    uploaded_image = st.file_uploader(
        "Upload an image (1 max)", type=["jpg", "jpeg", "png"], key="image_uploader"
    )

    if uploaded_image:
        try:
            image = Image.open(uploaded_image).convert("RGB")
            st.image(image, caption="Uploaded Image", use_container_width=True)
        except Exception as e:
            st.error(f"Failed to display the image: {e}")

    # Text input field
    input_text = st.text_area("Enter text (max 4096 characters)", "", max_chars=4096)

    # Sliders for top_k values
    col1, col2 = st.columns(2)
    with col1:
        top_k_text = st.slider(
            "Top-k Text Evidences", min_value=1, max_value=5, value=2, key="top_k_text"
        )
    with col2:
        top_k_image = st.slider(
            "Top-k Image Evidences",
            min_value=1,
            max_value=5,
            value=2,
            key="top_k_image",
        )

    # Verify Claim button
    if st.button("Verify Claim"):
        if not uploaded_image and not input_text:
            st.warning("Please upload an image or enter text.")
            return

        progress = st.progress(0)

        # Step 1: Generate caption
        progress.progress(10)
        st.write("### Step 1: Generating caption...")
        image_caption = ""
        if uploaded_image:
            image_caption = generate_caption(image)
            st.write("**Generated Image Caption:**", image_caption)

        # Step 2: Enrich text
        progress.progress(40)
        st.write("### Step 2: Enriching text...")
        enriched_text = enrich_text_with_caption(input_text, image_caption)
        st.write("**Enriched Text:**")
        st.write(enriched_text)

        # Step 3: Retrieve evidences by text
        progress.progress(50)
        st.write("### Step 3: Retrieving evidences by text...")
        if input_text:
            text_evidences = retrieve_evidences_by_text(enriched_text, top_k=top_k_text)
            st.write(f"Retrieved {len(text_evidences)} text evidences.")
        else:
            text_evidences = None
            st.write("Text modality is missing from the input claim!")

        # Step 4: Retrieve evidences by image
        progress.progress(70)
        st.write("### Step 4: Retrieving evidences by image...")
        if uploaded_image:
            image_evidences = retrieve_evidences_by_image(
                uploaded_image, top_k=top_k_image
            )
            st.write(f"Retrieved {len(image_evidences)} image evidences.")
        else:
            image_evidences = None
            st.write("Image modality is missing from the input claim!")

        # Step 5: Classify evidences
        progress.progress(90)
        st.write("### Step 5: Verifying claim with retrieved evidences...")
        for evidence in (text_evidences or []) + (image_evidences or []):
            a, b, c, d = classify_evidence(
                claim_text=enriched_text,
                claim_image_path=uploaded_image,
                evidence_text=evidence.text,
                evidence_image_path=evidence.image_path,
            )
            evidence.classification_result_all = a, b, c, d
            evidence.classification_result_final = get_final_classification(
                evidence.classification_result_all
            )

        # Step 6: Display evidences
        progress.progress(100)
        if text_evidences or image_evidences:
            st.write("## Results")
            tabs = st.tabs(["Text Evidences", "Image Evidences"])

            with tabs[0]:
                if text_evidences:
                    st.write("### Text Evidences")
                    display_evidence_tab(text_evidences, "text")
                else:
                    st.write("Text modality is missing from the input claim!")

            with tabs[1]:
                if image_evidences:
                    st.write("### Image Evidences")
                    display_evidence_tab(image_evidences, "image")
                else:
                    st.write("Image modality is missing from the input claim!")


if __name__ == "__main__":
    main()
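As a quick reference, the snippet below sketches how the fusion rule in `get_final_classification` resolves a few illustrative prediction tuples. The tuples are hypothetical; note that importing `app` also loads the BLIP captioning model, so in practice the function is exercised through the Streamlit app itself.

from app import get_final_classification  # importing app also initializes BLIP

# Tuple order matches classify_evidence: (text|text, text|image, image|text, image|image)
examples = [
    ("support", "not_enough_information", "not_enough_information", "support"),
    ("not_enough_information", "refute", "not_enough_information", "not_enough_information"),
    ("support", "not_enough_information", "not_enough_information", "refute"),
]
for preds in examples:
    print(preds, "->", get_final_classification(preds))
# Expected output: "support", "refute", "not_enough_information" respectively,
# since agreement on a same-modality pair wins, a single decisive prediction
# backed by "not_enough_information" is kept, and conflicts fall through.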
ckpts/model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15237d481c551aba1df0bae16f0adf43b23ba019e138712010453bda62d39bd0
size 51850010
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes).
 
src/data_loader/__init__.py ADDED
File without changes
src/data_loader/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (202 Bytes).
 
src/data_loader/__pycache__/download_data.cpython-311.pyc ADDED
Binary file (4.94 kB).
 
src/data_loader/__pycache__/download_images.cpython-311.pyc ADDED
Binary file (8.03 kB).
 
src/data_loader/download_data.py ADDED
@@ -0,0 +1,76 @@
import os
import zipfile
import gdown
from getpass import getpass
import shutil
from pathlib import Path
from src.utils.path_utils import get_project_root

# Constants
PROJECT_ROOT = get_project_root()
ZIP_FILE_PATH = str(PROJECT_ROOT / "data/raw/factify/factify_data.zip")
EXTRACTION_DIR = str(PROJECT_ROOT / "data/raw/factify/extracted")
TEMP_EXTRACTION_DIR = str(PROJECT_ROOT / "data/raw/factify/public_folder")
GDRIVE_FILE_URL = "https://drive.google.com/uc?id=1ig7XEYU1UKDHrHgDYgqiARWvNdswgFEX"


def ensure_directories():
    """Ensure necessary directories exist."""
    os.makedirs(os.path.dirname(ZIP_FILE_PATH), exist_ok=True)


def download_zip():
    """Download the ZIP file if it doesn't already exist."""
    if os.path.exists(ZIP_FILE_PATH):
        print(f"Zip file already exists at {ZIP_FILE_PATH}. Skipping download...")
        return
    print("Downloading zip file from Google Drive...")
    gdown.download(GDRIVE_FILE_URL, ZIP_FILE_PATH, quiet=False)
    print(f"Downloaded zip file to {ZIP_FILE_PATH}")


def extract_zip():
    """Extract the ZIP file and handle folder and file renaming."""
    train_csv_path = os.path.join(EXTRACTION_DIR, "train.csv")
    if os.path.exists(train_csv_path):
        print(f"{train_csv_path} already exists. Skipping extraction...")
        return
    print("Extracting zip file...")
    # Get password for the zip file
    password = getpass("Enter the password for the zip file: ")
    with zipfile.ZipFile(ZIP_FILE_PATH, "r") as zip_ref:
        try:
            zip_ref.extractall(
                str(PROJECT_ROOT / "data/raw/factify/"), pwd=password.encode()
            )
            print(f"Extracted files to temporary folder: {TEMP_EXTRACTION_DIR}")
        except RuntimeError:
            print("Incorrect password. Exiting...")
            exit(1)

    # Remove existing extracted directory if it exists
    if os.path.exists(EXTRACTION_DIR):
        shutil.rmtree(EXTRACTION_DIR)
        print(f"Removed existing directory: {EXTRACTION_DIR}")

    # Rename extracted folder
    if os.path.exists(TEMP_EXTRACTION_DIR):
        os.rename(TEMP_EXTRACTION_DIR, EXTRACTION_DIR)
        print(f"Renamed folder {TEMP_EXTRACTION_DIR} to {EXTRACTION_DIR}")

    # Rename val.csv to test.csv
    val_csv_path = os.path.join(EXTRACTION_DIR, "val.csv")
    test_csv_path = os.path.join(EXTRACTION_DIR, "test.csv")
    if os.path.exists(val_csv_path):
        os.rename(val_csv_path, test_csv_path)
        print(f"Renamed {val_csv_path} to {test_csv_path}")


def main():
    ensure_directories()
    download_zip()
    extract_zip()


if __name__ == "__main__":
    main()
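After extraction, a quick sanity check of the CSVs can be useful. This is a sketch only: the path follows the constants above, and the tab-separated layout with a header row mirrors how src/data_loader/download_images.py later reads these files (sep="\t", header row skipped).

import pandas as pd

# Peek at the extracted training split without loading the whole file.
df = pd.read_csv("data/raw/factify/extracted/train.csv", sep="\t", nrows=5)
print(df.shape)
print(df.columns.tolist())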
src/data_loader/download_data_mocheg.py ADDED
@@ -0,0 +1,71 @@
import os
import requests
import tarfile
from tqdm import tqdm

DATA_URL: str = (
    "http://nlplab1.cs.vt.edu/~menglong/project/multimodal/fact_checking/MOCHEG/dataset/latest_dataset/mocheg_with_tweet_2023_03.tar.gz"
)
RAW_DATA_DIR: str = "data/raw"
ARCHIVE_NAME: str = "mocheg_with_tweet_2023_03.tar.gz"
CHUNK_SIZE: int = 16 * 1024 * 1024  # 16 MB

# Ensure the raw data directory exists
os.makedirs(RAW_DATA_DIR, exist_ok=True)
archive_path: str = os.path.join(RAW_DATA_DIR, ARCHIVE_NAME)


def check_disk_space(required_space_gb: int) -> bool:
    """Check if there is enough free disk space."""
    stat = os.statvfs(RAW_DATA_DIR)
    free_space_gb: float = (stat.f_bavail * stat.f_frsize) / (1024**3)
    return free_space_gb > required_space_gb


def download_data() -> None:
    """Download the data if not already present and extract it."""
    # Check if the data file already exists
    if os.path.exists(archive_path):
        print(f"Data already downloaded at {archive_path}. Skipping download.")
        return

    # Ensure enough disk space (approximate)
    required_space_gb: int = 80  # Adjust based on expected file size + extraction space
    if not check_disk_space(required_space_gb):
        print(f"Not enough disk space. At least {required_space_gb} GB required.")
        return

    # Download the data in larger chunks
    print(f"Downloading data from {DATA_URL}...")
    response = requests.get(DATA_URL, stream=True)
    response.raise_for_status()  # Ensure the URL is accessible

    total_size: int = int(response.headers.get("content-length", 0))
    with open(archive_path, "wb") as file, tqdm(
        desc=ARCHIVE_NAME,
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
            if chunk:
                file.write(chunk)
                progress_bar.update(len(chunk))

    print(f"Download completed: {archive_path}")

    # Extract the tar.gz file
    extract_data(archive_path)


def extract_data(archive_path: str) -> None:
    """Extract the downloaded tar.gz file."""
    print(f"Extracting data from {archive_path}...")
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=RAW_DATA_DIR)
    print(f"Data extracted to {RAW_DATA_DIR}")


if __name__ == "__main__":
    download_data()
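Because the script budgets roughly 80 GB of free space, it can help to peek at the archive before a full extraction. A minimal sketch, reusing the same download location as above:

import tarfile

ARCHIVE_PATH = "data/raw/mocheg_with_tweet_2023_03.tar.gz"

# Stream through the archive and list the first few members without extracting.
with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
    for i, member in enumerate(tar):
        print(member.name)
        if i == 9:  # stop after ten entries
            break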
src/data_loader/download_images.py ADDED
@@ -0,0 +1,168 @@
import os
import argparse
import pandas as pd
import requests
import json
import io
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict
from PIL import Image

from src.utils.data_utils import HEADERS
from src.utils.path_utils import get_project_root

# Constants
PROJECT_ROOT = get_project_root()
EXTRACTION_DIR = str(PROJECT_ROOT / "data/raw/factify/extracted")
IMAGES_DIR = os.path.join(EXTRACTION_DIR, "images")


def ensure_directories(images_folder):
    """Ensure the image directory exists."""
    os.makedirs(images_folder, exist_ok=True)


def download_image(url, save_path):
    """Download a single image if not already downloaded."""
    # Check if the image already exists
    if os.path.exists(save_path):
        print(f"Image already exists: {save_path}")
        return True

    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
        )
    }
    try:
        response = requests.get(url, headers=headers, stream=True, timeout=30)
        response.raise_for_status()  # Raise an error for HTTP issues
        img = Image.open(io.BytesIO(response.content))
        img = img.convert("RGB")  # Ensure the image is in RGB format
        img.save(save_path)
        print(f"Downloaded and saved image: {save_path}")
        return True
    except Exception as e:
        print(f"Failed to download image from {url}: {e}")
        return False


def process_image(row, images_folder, stats, dataset_name):
    """Process claim and evidence image downloads."""
    file_id = str(row["id"])
    category = row.get("category", "Unknown")
    claim_image_url = row.get("claim_image", "")
    evidence_image_url = row.get("evidence_image", "")

    # Ensure category stats exist
    stats["categories"].setdefault(
        category,
        {
            "total_claim": 0,
            "successful_claim": 0,
            "total_evidence": 0,
            "successful_evidence": 0,
        },
    )
    stats["categories"][category]["total_claim"] += 1
    stats["categories"][category]["total_evidence"] += 1

    # Download claim image
    if claim_image_url:
        success = download_image(
            claim_image_url, os.path.join(images_folder, f"{file_id}_claim.jpg")
        )
        if success:
            stats["successful_claim"] += 1
            stats["categories"][category]["successful_claim"] += 1

    # Download evidence image
    if evidence_image_url:
        success = download_image(
            evidence_image_url, os.path.join(images_folder, f"{file_id}_evidence.jpg")
        )
        if success:
            stats["successful_evidence"] += 1
            stats["categories"][category]["successful_evidence"] += 1


def download_images(dataset, use_threading):
    """Download images for the specified dataset (train or test)."""
    csv_path = os.path.join(EXTRACTION_DIR, f"{dataset}.csv")
    images_folder = os.path.join(IMAGES_DIR, dataset)
    stats_file_path = os.path.join(
        EXTRACTION_DIR, f"{dataset}_image_download_stats.json"
    )
    ensure_directories(images_folder)

    if not os.path.exists(csv_path):
        print(f"CSV file not found for {dataset}: {csv_path}")
        return

    stats = {
        "successful_claim": 0,
        "successful_evidence": 0,
        "categories": defaultdict(
            lambda: {
                "total_claim": 0,
                "successful_claim": 0,
                "total_evidence": 0,
                "successful_evidence": 0,
            }
        ),
    }

    df = pd.read_csv(csv_path, names=HEADERS, header=None, sep="\t", skiprows=1)

    if use_threading:
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(process_image, row, images_folder, stats, dataset)
                for _, row in df.iterrows()
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc=f"Downloading {dataset} images",
            ):
                pass
    else:
        for _, row in tqdm(
            df.iterrows(), total=len(df), desc=f"Downloading {dataset} images"
        ):
            process_image(row, images_folder, stats, dataset)

    with open(stats_file_path, "w") as stats_file:
        json.dump(stats, stats_file, indent=4)
    print(f"Image download stats saved to {stats_file_path}")


def main():
    parser = argparse.ArgumentParser(description="Download images for Factify dataset.")
    parser.add_argument(
        "--dataset",
        choices=["train", "test"],
        help="Specify which dataset to download images for (train or test). If not specified, both will be downloaded.",
    )
    parser.add_argument(
        "--use-threading",
        # BooleanOptionalAction makes the flag toggleable (--use-threading / --no-use-threading);
        # a plain store_true with default=True could never be turned off.
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable threading for image downloads (default: True).",
    )
    args = parser.parse_args()

    if args.dataset:
        # Run for the specified dataset
        download_images(args.dataset, args.use_threading)
    else:
        # Run for both train and test if no dataset is specified
        print("No dataset specified. Downloading images for both train and test...")
        for dataset in ["train", "test"]:
            download_images(dataset, args.use_threading)


if __name__ == "__main__":
    main()
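The downloader can also be driven programmatically rather than through the CLI. A minimal sketch, assuming the package is importable from the project root:

from src.data_loader.download_images import download_images

# Fetch images for the training split only, using the threaded download path.
download_images("train", use_threading=True)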
src/data_loader/preprocess_embeddings.py ADDED
@@ -0,0 +1,129 @@
import h5py
import torch
import logging
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, Swinv2Model

logger = logging.getLogger(__name__)


@torch.no_grad()
def create_embeddings_h5(input_h5_path, output_h5_path, batch_size=32, device="cuda"):
    """
    Create a new H5 file with pre-computed embeddings from text and images.

    Args:
        input_h5_path (str): Path to input H5 file with raw data
        output_h5_path (str): Path where to save the new H5 file with embeddings
        batch_size (int): Batch size for processing
        device (str): Device to use for computation
    """
    logger.info(f"Creating embeddings H5 file from {input_h5_path}")

    # Initialize models
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
    text_encoder = AutoModel.from_pretrained("microsoft/deberta-v3-xsmall").to(device)
    image_encoder = Swinv2Model.from_pretrained(
        "microsoft/swinv2-base-patch4-window8-256"
    ).to(device)

    # Set models to eval mode
    text_encoder.eval()
    image_encoder.eval()

    # Open input H5 file
    with h5py.File(input_h5_path, "r") as in_f, h5py.File(output_h5_path, "w") as out_f:
        total_samples = len(in_f.keys())

        # Process in batches
        for batch_start in tqdm(range(0, total_samples, batch_size)):
            batch_end = min(batch_start + batch_size, total_samples)
            batch_indices = range(batch_start, batch_end)

            # Collect batch data
            claim_texts = []
            doc_texts = []
            claim_images = []
            doc_images = []
            labels = []

            for idx in batch_indices:
                sample = in_f[str(idx)]
                claim_texts.append(sample["claim"][()].decode())
                doc_texts.append(sample["document"][()].decode())
                claim_images.append(torch.from_numpy(sample["claim_image"][()]))
                doc_images.append(torch.from_numpy(sample["document_image"][()]))
                labels.append(sample["labels"][()])

            # Convert to tensors
            claim_images = torch.stack(claim_images).to(device)
            doc_images = torch.stack(doc_images).to(device)

            # Get text embeddings with fixed sequence length
            claim_text_inputs = tokenizer(
                claim_texts,
                truncation=True,
                padding="max_length",  # Changed to max_length
                return_tensors="pt",
                max_length=512,
            ).to(device)

            doc_text_inputs = tokenizer(
                doc_texts,
                truncation=True,
                padding="max_length",  # Changed to max_length
                return_tensors="pt",
                max_length=512,
            ).to(device)

            claim_text_embeds = text_encoder(**claim_text_inputs).last_hidden_state
            doc_text_embeds = text_encoder(**doc_text_inputs).last_hidden_state

            # Verify shapes
            assert (
                claim_text_embeds.shape[1] == 512
            ), f"Unexpected claim text shape: {claim_text_embeds.shape}"
            assert (
                doc_text_embeds.shape[1] == 512
            ), f"Unexpected doc text shape: {doc_text_embeds.shape}"

            # Get image embeddings
            claim_image_embeds = image_encoder(claim_images).last_hidden_state
            doc_image_embeds = image_encoder(doc_images).last_hidden_state

            # Store embeddings and labels
            for batch_idx, idx in enumerate(batch_indices):
                sample_group = out_f.create_group(str(idx))

                # Store embeddings
                sample_group.create_dataset(
                    "claim_text_embeds", data=claim_text_embeds[batch_idx].cpu().numpy()
                )
                sample_group.create_dataset(
                    "doc_text_embeds", data=doc_text_embeds[batch_idx].cpu().numpy()
                )
                sample_group.create_dataset(
                    "claim_image_embeds",
                    data=claim_image_embeds[batch_idx].cpu().numpy(),
                )
                sample_group.create_dataset(
                    "doc_image_embeds", data=doc_image_embeds[batch_idx].cpu().numpy()
                )

                # Store labels
                sample_group.create_dataset("labels", data=labels[batch_idx])

    logger.info(f"Created embeddings H5 file at {output_h5_path}")


if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Example usage
    create_embeddings_h5(
        input_h5_path="data/preprocessed/train.h5",
        output_h5_path="data/preprocessed/train_embeddings.h5",
        batch_size=32,
        device="cuda:0",
    )
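A minimal sketch of reading the resulting file back. The group and dataset names match those created above; the path is the one used in the `__main__` example, and the hidden size depends on the encoder checkpoints.

import h5py

with h5py.File("data/preprocessed/train_embeddings.h5", "r") as f:
    sample = f["0"]  # one group per sample index
    print(sample["claim_text_embeds"].shape)   # (512, hidden_size) after max-length padding
    print(sample["claim_image_embeds"].shape)
    print(sample["doc_image_embeds"].shape)
    print(sample["labels"][()])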
src/demo/__init__.py ADDED
File without changes
src/demo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (195 Bytes).
 
src/demo/__pycache__/app.cpython-311.pyc ADDED
Binary file (16.5 kB).
 
src/demo/app.py ADDED
@@ -0,0 +1,446 @@
(contents identical to app.py above)
src/evidence/__init__.py ADDED
File without changes
src/evidence/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes).
 
src/evidence/__pycache__/corpus_utils.cpython-311.pyc ADDED
Binary file (4.06 kB).
 
src/evidence/__pycache__/im2im_retrieval.cpython-311.pyc ADDED
Binary file (10.3 kB).
 
src/evidence/__pycache__/text2text_retrieval.cpython-311.pyc ADDED
Binary file (10.3 kB).
 
src/evidence/corpus_utils.py ADDED
@@ -0,0 +1,100 @@
import os
import pickle
import shutil

from src.utils.path_utils import get_project_root


def separate_evidence_images(base_dir):
    """
    Separates evidence images from the train and test directories and copies them into a new 'evidence_corpus' folder.

    Args:
        base_dir (str): The base directory containing the 'train' and 'test' folders.
    """
    # Define paths
    datasets = ["train", "test"]
    evidence_corpus_dir = os.path.join(base_dir, "evidence_corpus")

    # Create the evidence_corpus directory if it doesn't exist
    os.makedirs(evidence_corpus_dir, exist_ok=True)

    # Loop through the dataset directories and copy evidence images
    for dataset in datasets:
        dataset_dir = os.path.join(base_dir, dataset)
        for filename in os.listdir(dataset_dir):
            if filename.split("_")[-1].split(".")[0] == "evidence":
                new_filename = f"{dataset}_{filename}"
                source_path = os.path.join(dataset_dir, filename)
                target_path = os.path.join(evidence_corpus_dir, new_filename)

                shutil.copy(source_path, target_path)

    print("All evidence images in the train and test sets have been copied.")


# File path for the evidence features pickle
pickle_file_path = "evidence_features.pkl"


# Function to update the keys in the pickle
def update_pickle_keys(pickle_file_path, output_pickle_path=None):
    # Open and load the existing pickle
    with open(pickle_file_path, "rb") as f:
        feature_dict = pickle.load(f)

    updated_dict = {}

    # Update each key
    for old_path, features in feature_dict.items():
        # Extract the filename (e.g., test_0_evidence.jpg)
        filename = os.path.basename(old_path)

        # Determine if it's a test or train image based on the filename
        if filename.startswith("test"):
            new_relative_path = os.path.join(
                "data",
                "raw",
                "factify",
                "extracted",
                "images",
                "test",
                filename.split("_", 1)[1],
            )
        elif filename.startswith("train"):
            new_relative_path = os.path.join(
                "data",
                "raw",
                "factify",
                "extracted",
                "images",
                "train",
                filename.split("_", 1)[1],
            )
        else:
            raise ValueError(f"Unexpected filename format: {filename}")

        # Add the updated key and its value to the new dictionary
        updated_dict[new_relative_path] = features

    # Save the updated dictionary back to a pickle file
    output_path = output_pickle_path if output_pickle_path else pickle_file_path
    with open(output_path, "wb") as f:
        pickle.dump(updated_dict, f)

    print(f"Updated pickle saved at: {output_path}")


# Example usage
if __name__ == "__main__":
    pickle_file_path = "/evidence_features.pkl"
    project_root = get_project_root()
    # Run the function
    base_dir = os.path.join(
        project_root, "data", "raw", "factify", "extracted", "images"
    )
    separate_evidence_images(base_dir)

    # out_pkl_path = "C:\\Users\\defne\\Desktop\\2024-2025FallSemester\\Applied NLP\\multimodal-misinformation-detection\\data\\raw\\factify\\extracted\\images"
    # update_pickle_keys(pickle_file_path, output_pickle_path=out_pkl_path)
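For clarity, a small self-contained illustration of the selection and renaming convention that `separate_evidence_images` applies; the filenames below are hypothetical.

import os

# Hypothetical filenames, used only to illustrate which files are copied and how they are renamed.
for dataset, filename in [
    ("train", "101_evidence.jpg"),
    ("train", "101_claim.jpg"),
    ("test", "7_evidence.jpg"),
]:
    is_evidence = filename.split("_")[-1].split(".")[0] == "evidence"
    target = (
        os.path.join("evidence_corpus", f"{dataset}_{filename}")
        if is_evidence
        else "(skipped: not an evidence image)"
    )
    print(f"{dataset}/{filename} -> {target}")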
src/evidence/im2im_retrieval.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from torchvision.models import resnet50
3
+ from torchvision.transforms import transforms
4
+ from PIL import Image
5
+ import torch.nn as nn
6
+ import torch
7
+ import pickle
8
+ import matplotlib.pyplot as plt
9
+ from src.utils.path_utils import get_project_root
10
+
11
+
12
+ class ImageSimilarity:
13
+ def __init__(self):
14
+ self.model = resnet50(weights="DEFAULT")
15
+ self.model = nn.Sequential(
16
+ *list(self.model.children())[:-1]
17
+ ) # Ignoring the last classification layer
18
+ self.model.eval()
19
+ self.transform = transforms.Compose(
20
+ [
21
+ transforms.Resize((224, 224)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(
24
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
25
+ ),
26
+ ]
27
+ )
28
+
29
+ def extract_features(self, image_stream):
30
+ image = Image.open(image_stream).convert("RGB")
31
+ image = self.transform(image).unsqueeze(0)
32
+
33
+ with torch.no_grad():
34
+ features = self.model(image)
35
+ features = features.flatten()
36
+ return features
37
+
38
+ def similarity(self, features1, features2):
39
+ # Calculating cosine similarity
40
+ cos = nn.CosineSimilarity(dim=1, eps=1e-6)
41
+ similarity = cos(features1.unsqueeze(0), features2.unsqueeze(0))
42
+ return similarity.item()
43
+
44
+
45
+ class ImageCorpus:
46
+ def __init__(self, feature_corpus_path):
47
+ self.feature_corpus_path = feature_corpus_path
48
+ self.feature_dict = self.load_features()
49
+ self.feature_extractor = ImageSimilarity()
50
+
51
+ def load_features(self):
52
+ try:
53
+ with open(self.feature_corpus_path, "rb") as f:
54
+ return pickle.load(f)
55
+ except (EOFError, pickle.UnpicklingError):
56
+ print(
57
+ "Warning: Pickle file is empty or corrupted. Initializing empty feature dict."
58
+ )
+ return {} # fall back to an empty corpus so add_image/create_feature_corpus can still populate it
59
+
60
+ def save_features(self):
61
+ with open(self.feature_corpus_path, "wb") as f:
62
+ pickle.dump(self.feature_dict, f)
63
+
64
+ def add_image(self, image_path):
65
+ features = self.feature_extractor.extract_features(image_path)
66
+ self.feature_dict[image_path] = features
67
+ self.save_features()
68
+
69
+ def create_feature_corpus(self, image_dir):
70
+ for image_name in os.listdir(image_dir):
71
+ image_path = os.path.join(image_dir, image_name)
72
+ if os.path.isfile(image_path) and image_path.lower().endswith(
73
+ (".png", ".jpg", ".jpeg")
74
+ ):
75
+ features = self.feature_extractor.extract_features(image_path)
76
+ self.feature_dict[image_path] = features
77
+
78
+ self.save_features()
79
+
80
+ def retrieve_similar_images(self, query_image_path, top_k=50):
81
+ query_features = self.feature_extractor.extract_features(query_image_path)
82
+ similarity_scores = {}
83
+
84
+ for image_name, corpus_feature in self.feature_dict.items():
85
+ similarity = self.feature_extractor.similarity(
86
+ query_features, corpus_feature
87
+ )
88
+ similarity_scores[image_name] = similarity
89
+
90
+ retrieved_images = sorted(
91
+ similarity_scores.items(), key=lambda x: x[1], reverse=True
92
+ )
93
+
94
+ # Filter out identical images (based on scores)
95
+ unique_scores = set()
96
+ filtered_images = []
97
+
98
+ for image_path, score in retrieved_images:
99
+ if score not in unique_scores: # Check if this score is already added
100
+ unique_scores.add(score)
101
+ filtered_images.append((image_path, score))
102
+
103
+ if len(filtered_images) == top_k: # Stop once we have top_k unique images
104
+ break
105
+
106
+ return filtered_images
107
+
108
+
109
+ def visualize_retrieved_images(query_image_path, top_retrievals):
110
+ # Load query image
111
+
112
+ query_image = Image.open(query_image_path).convert("RGB")
113
+ project_base = get_project_root()
114
+ # Load retrieved images and their scores
115
+ retrieved_images = [
116
+ (Image.open(os.path.join(project_base, img_path)).convert("RGB"), score)
117
+ for img_path, score in top_retrievals
118
+ ]
119
+
120
+ # Set up the grid for visualization
121
+ total_retrieved = len(retrieved_images)
122
+ rows = 2 + (total_retrieved - 1) // 5 # 1 row for query + rows for 5 images per row
123
+ cols = 5
124
+
125
+ # Set figure size
126
+ plt.figure(figsize=(20, rows * 4))
127
+
128
+ # Plot query image at the top row (centered in row of 5)
129
+ plt.subplot(rows, cols, (cols // 2) + 1) # Center in the first row
130
+ plt.imshow(query_image)
131
+ plt.title("Query Image", fontsize=12)
132
+ plt.axis("off")
133
+
134
+ # Plot retrieved images
135
+ for idx, (img, score) in enumerate(retrieved_images):
136
+ plt.subplot(rows, cols, cols + idx + 1) # Start plotting after the query image
137
+ plt.imshow(img)
138
+ plt.title(f"Rank: {idx+1}\nScore: {score:.4f}", fontsize=10)
139
+ plt.axis("off")
140
+
141
+ plt.tight_layout()
142
+ plt.show()
143
+
144
+
145
+ if __name__ == "__main__":
146
+ project_root = get_project_root()
147
+ image_feature = os.path.join(project_root, "evidence_features.pkl")
148
+ image_dir = os.path.join(
149
+ project_root, "data", "raw", "factify", "extracted", "images", "evidence_corpus"
150
+ ) # Replace with your base directory path
151
+
152
+ query_image_path = os.path.join(
153
+ project_root,
154
+ "data",
155
+ "raw",
156
+ "factify",
157
+ "extracted",
158
+ "images",
159
+ "train",
160
+ "1_claim.jpg",
161
+ )
162
+
163
+ image_corpus = ImageCorpus(image_feature)
164
+ # corpus = image_corpus.create_feature_corpus(image_dir)
165
+ print(list(image_corpus.feature_dict.keys())[0])
166
+
167
+ top_retrievals = image_corpus.retrieve_similar_images(query_image_path, top_k=5)
168
+ print(top_retrievals)
169
+ visualize_retrieved_images(query_image_path, top_retrievals)
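A minimal usage sketch of ImageSimilarity on its own, for comparing two local images without the pickled corpus (the image paths are hypothetical and follow the {id}_claim.jpg / {id}_evidence.jpg naming used above):

from src.evidence.im2im_retrieval import ImageSimilarity

sim = ImageSimilarity()
# hypothetical paths relative to the project root
features_a = sim.extract_features("data/raw/factify/extracted/images/train/1_claim.jpg")
features_b = sim.extract_features("data/raw/factify/extracted/images/train/1_evidence.jpg")
print(f"Cosine similarity: {sim.similarity(features_a, features_b):.4f}")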
src/evidence/text2text_retrieval.py ADDED
@@ -0,0 +1,203 @@
1
+ import h5py
2
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
3
+ import os
4
+ import torch
5
+ import pandas as pd
6
+
7
+ from src.utils.path_utils import get_project_root
8
+
9
+
10
+ class SemanticSimilarity:
11
+ def __init__(
12
+ self,
13
+ train_embeddings_file,
14
+ test_embeddings_file,
15
+ train_csv_path=None,
16
+ test_csv_path=None,
17
+ train_df=None,
18
+ test_df=None,
19
+ ):
20
+ # We use the Bi-Encoder to encode all passages
21
+ self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
22
+ self.bi_encoder.max_seq_length = 512 # Truncate long passages to 512 tokens
23
+
24
+ self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
25
+
26
+ self.train_embeddings, self.train_ids = self._load_embeddings(
27
+ train_embeddings_file
28
+ )
29
+ self.test_embeddings, self.test_ids = self._load_embeddings(
30
+ test_embeddings_file
31
+ )
32
+
33
+ # Load corresponding CSV files for enriched evidence
34
+ self.train_csv = (
35
+ train_df if train_df is not None else pd.read_csv(train_csv_path)
36
+ )
37
+ self.test_csv = test_df if test_df is not None else pd.read_csv(test_csv_path)
38
+
39
+ def _load_embeddings(self, h5_file_path):
40
+ """
41
+ Load embeddings and IDs from the HDF5 file
42
+ """
43
+ with h5py.File(h5_file_path, "r") as h5_file:
44
+ embeddings = torch.tensor(h5_file["embeddings"][:], dtype=torch.float16)
45
+ ids = list(h5_file["ids"][:]) # Retrieve the IDs as a list of strings
46
+
47
+ return embeddings, ids
48
+
49
+ def search(self, query, top_k):
50
+ ##### Semantic Search #####
51
+ # Encode the query using the bi-encoder and find potentially relevant passages
52
+ question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
53
+ question_embedding = question_embedding.to(dtype=torch.float16)
54
+ # question_embedding = question_embedding
55
+
56
+ hits_train = util.semantic_search(
57
+ question_embedding, self.train_embeddings, top_k=top_k * 5
58
+ )
59
+ hits_train = hits_train[0] # Get the hits for the first query
60
+ # print(f"len(hits_train) = {len(hits_train)}")
61
+ hits_test = util.semantic_search(
62
+ question_embedding, self.test_embeddings, top_k=top_k * 5
63
+ )
64
+ hits_test = hits_test[0]
65
+ # print(f"len(hits_test): {len(hits_test)}")
66
+
67
+ ##### Re-Ranking #####
68
+ # Now, score all retrieved passages with the cross_encoder
69
+ cross_inp_train = [
70
+ [query, self.train_csv["evidence_enriched"][hit["corpus_id"]]]
71
+ for hit in hits_train
72
+ ]
73
+ cross_scores_train = self.cross_encoder.predict(cross_inp_train)
74
+
75
+ cross_inp_test = [
76
+ [query, self.test_csv["evidence_enriched"][hit["corpus_id"]]]
77
+ for hit in hits_test
78
+ ]
79
+ cross_scores_test = self.cross_encoder.predict(cross_inp_test)
80
+
81
+ # Sort results by the cross-encoder scores
82
+ for idx in range(len(cross_scores_train)):
83
+ hits_train[idx]["cross-score"] = cross_scores_train[idx]
84
+
85
+ for idx in range(len(cross_scores_test)):
86
+ hits_test[idx]["cross-score"] = cross_scores_test[idx]
87
+
88
+ hits_train_cross_encoder = sorted(
89
+ hits_train, key=lambda x: x.get("cross-score"), reverse=True
90
+ )
91
+ hits_train_cross_encoder = hits_train_cross_encoder[: top_k * 5]
92
+ hits_test_cross_encoder = sorted(
93
+ hits_test, key=lambda x: x.get("cross-score"), reverse=True
94
+ )
95
+ hits_test_cross_encoder = hits_test_cross_encoder[: top_k * 5]
96
+
97
+ results = [
98
+ (self.train_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
99
+ for hit in hits_train_cross_encoder
100
+ ] + [
101
+ (self.test_ids[hit["corpus_id"]].decode("utf-8"), hit.get("cross-score"))
102
+ for hit in hits_test_cross_encoder
103
+ ]
104
+
105
+ ##### Filter out duplicates based on scores #####
106
+ unique_scores = set()
107
+ filtered_results = []
108
+
109
+ # print(results)
110
+ for id_, score in sorted(results, key=lambda x: x[1], reverse=True):
111
+ if score not in unique_scores:
112
+ unique_scores.add(score)
113
+ filtered_results.append((id_, score))
114
+
115
+ if (
116
+ len(filtered_results) == top_k
117
+ ): # Stop when top_k unique scores are reached
118
+ break
119
+
120
+ return filtered_results
121
+
122
+
123
+ class TextCorpus:
124
+ def __init__(self, data_dir, split):
125
+ self.bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
126
+ self.split = split # train evidences or test evidences
127
+ self.data_dir = data_dir # .csv file for enriched train and test is contained.
128
+
129
+ def encode_corpus(self):
130
+ """
131
+ Encode the corpus (evidence_enriched column for both train and test) and store the embeddings.
132
+ """
133
+ file_path = os.path.join(self.data_dir, f"{self.split}_enriched.csv")
134
+ df = pd.read_csv(file_path)
135
+
136
+ # Extract the enriched evidence column and ids
137
+ evidence_enriched = df["evidence_enriched"].tolist()
138
+ ids = df["id"].tolist() # Assuming the 'id' column is in the CSV
139
+
140
+ # Encode the evidence using the bi-encoder
141
+ embeddings = self.bi_encoder.encode(evidence_enriched, convert_to_tensor=True)
142
+
143
+ # Define HDF5 file path
144
+ h5_file_path = os.path.join(get_project_root(), f"{self.split}_embeddings.h5")
145
+
146
+ with h5py.File(h5_file_path, "w") as h5_file:
147
+ h5_file.create_dataset(
148
+ "embeddings", data=embeddings.numpy(), dtype="float16"
149
+ )
150
+
151
+ h5_file.create_dataset(
152
+ "ids",
153
+ data=[f"{self.split}_{id}" for id in ids],
154
+ dtype=h5py.string_dtype(),
155
+ )
156
+
157
+ print(f"Embeddings saved to {h5_file_path}")
158
+
159
+
160
+ if __name__ == "__main__":
161
+ import time
162
+
163
+ start_time = time.time()
164
+ project_root = get_project_root()
165
+ data_dir = os.path.join(project_root, "data", "preprocessed")
166
+
167
+ # query = train_enriched['evidence_enriched'][0]
168
+ # train_embeddings = os.path.join(get_project_root(), 'train_evidence_embeddings.pkl')
169
+ # test_embeddings = os.path.join(get_project_root(), 'test_evidence_embeddings.pkl')
170
+
171
+ # semantic = SemanticSimilarity(train_embeddings, test_embeddings)
172
+ # semantic.search(query, top_k=10)
173
+
174
+ # evidence = TextCorpus(data_dir, 'train')
175
+
176
+ # Define file paths
177
+ train_csv_path = os.path.join(data_dir, "train_enriched.csv")
178
+ test_csv_path = os.path.join(data_dir, "test_enriched.csv")
179
+ train_embeddings_file = os.path.join(project_root, "train_embeddings.h5")
180
+ test_embeddings_file = os.path.join(project_root, "test_embeddings.h5")
181
+
182
+ # Initialize the SemanticSimilarity class
183
+ similarity = SemanticSimilarity(
184
+ train_embeddings_file=train_embeddings_file,
185
+ test_embeddings_file=test_embeddings_file,
186
+ train_csv_path=train_csv_path,
187
+ test_csv_path=test_csv_path,
188
+ )
189
+
190
+ # Load the first query from train_enriched.csv
191
+ train_df = pd.read_csv(train_csv_path)
192
+ first_query = train_df["claim_enriched"].iloc[2] # Get the first query
193
+
194
+ # Define the number of top-k results to retrieve
195
+ top_k = 5
196
+
197
+ # Perform the semantic search
198
+ results = similarity.search(query=first_query, top_k=top_k)
199
+ finish_time = time.time() - start_time
200
+ # Display the results
201
+
202
+ print(results)
203
+ print(f"Finish time: {finish_time}")
src/experimental/__init__.py ADDED
File without changes
src/experimental/dataset_search.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/experimental/dataset_stats.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/experimental/image_captioning.ipynb ADDED
@@ -0,0 +1,96 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "ExecuteTime": {
9
+ "end_time": "2024-12-14T14:40:23.089485Z",
10
+ "start_time": "2024-12-14T14:40:22.937392Z"
11
+ }
12
+ },
13
+ "source": [
14
+ "import pandas as pd\n",
15
+ "from src.utils.path_utils import get_project_root\n",
16
+ "\n",
17
+ "PROJECT_ROOT = get_project_root()"
18
+ ],
19
+ "outputs": [],
20
+ "execution_count": 1
21
+ },
22
+ {
23
+ "metadata": {
24
+ "ExecuteTime": {
25
+ "end_time": "2024-12-14T14:46:49.718444Z",
26
+ "start_time": "2024-12-14T14:46:46.361765Z"
27
+ }
28
+ },
29
+ "cell_type": "code",
30
+ "source": [
31
+ "import requests\n",
32
+ "from PIL import Image\n",
33
+ "from transformers import BlipProcessor, BlipForConditionalGeneration\n",
34
+ "\n",
35
+ "processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-large\")\n",
36
+ "model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-large\")\n",
37
+ "\n",
38
+ "image = Image.open(f\"{PROJECT_ROOT}/data/scenery_image.jpg\")\n",
39
+ "\n",
40
+ "# conditional image captioning\n",
41
+ "text = \"a photography of\"\n",
42
+ "inputs = processor(image, text, return_tensors=\"pt\")\n",
43
+ "\n",
44
+ "out = model.generate(**inputs)\n",
45
+ "print(processor.decode(out[0], skip_special_tokens=True))\n",
46
+ "\n",
47
+ "# unconditional image captioning\n",
48
+ "inputs = processor(image, return_tensors=\"pt\")\n",
49
+ "\n",
50
+ "out = model.generate(**inputs)\n",
51
+ "print(processor.decode(out[0], skip_special_tokens=True))\n"
52
+ ],
53
+ "id": "80b41a616dbbafd3",
54
+ "outputs": [
55
+ {
56
+ "name": "stdout",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "a photography of a road leading to mountains with a sunset in the background\n",
60
+ "arafed road with mountains in the background and a sunset\n"
61
+ ]
62
+ }
63
+ ],
64
+ "execution_count": 8
65
+ },
66
+ {
67
+ "metadata": {},
68
+ "cell_type": "code",
69
+ "outputs": [],
70
+ "execution_count": null,
71
+ "source": "",
72
+ "id": "983b19a8aa6e4a39"
73
+ }
74
+ ],
75
+ "metadata": {
76
+ "kernelspec": {
77
+ "display_name": "Python 3",
78
+ "language": "python",
79
+ "name": "python3"
80
+ },
81
+ "language_info": {
82
+ "codemirror_mode": {
83
+ "name": "ipython",
84
+ "version": 2
85
+ },
86
+ "file_extension": ".py",
87
+ "mimetype": "text/x-python",
88
+ "name": "python",
89
+ "nbconvert_exporter": "python",
90
+ "pygments_lexer": "ipython2",
91
+ "version": "2.7.6"
92
+ }
93
+ },
94
+ "nbformat": 4,
95
+ "nbformat_minor": 5
96
+ }
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (197 Bytes). View file
 
src/model/__pycache__/layers.cpython-311.pyc ADDED
Binary file (4.08 kB). View file
 
src/model/__pycache__/model.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
src/model/dataset.py ADDED
@@ -0,0 +1,164 @@
1
+ import os
2
+ import h5py
3
+ import pandas as pd
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import logging
9
+ import numpy as np
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Define preprocessing transformations
14
+ preprocess = transforms.Compose([
15
+ transforms.Resize(256),
16
+ transforms.CenterCrop(256),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.229, 0.224, 0.225]),
19
+ ])
20
+
21
+ # Updated category mapping for multi-label classification
22
+ # Each category maps to (text-text, text-image, image-text, image-image) labels
23
+ # 0: Support, 1: NEI (Not Enough Information), 2: Refute
24
+ category_to_labels = {
25
+ 'Support_Text': [0, 1, 1, 1], # Support only for text-text
26
+ 'Support_Multimodal': [0, 0, 0, 0], # Support for all paths
27
+ 'Insufficient_Text': [1, 1, 1, 1], # NEI for all paths
28
+ 'Insufficient_Multimodal': [1, 1, 1, 0], # Support for image-image, NEI for the other paths
29
+ 'Refute': [2, 2, 2, 2] # Refute for all paths
30
+ }
31
+
32
+ def prepare_h5_dataset(csv_path, h5_path):
33
+ """
34
+ Prepare h5 dataset from CSV file where each index contains complete sample data
35
+ """
36
+ # Create output directory if it doesn't exist
37
+ os.makedirs(os.path.dirname(h5_path), exist_ok=True)
38
+
39
+ # Read CSV file
40
+ df = pd.read_csv(csv_path, index_col=0)[['claim', 'claim_image', 'evidence', 'evidence_image', 'category']]
41
+
42
+ with h5py.File(h5_path, 'w') as f:
43
+ # Process each row
44
+ for idx, (_, row) in enumerate(df.iterrows()):
45
+ # Create group for this sample
46
+ sample_group = f.create_group(str(idx))
47
+
48
+ # Store text data
49
+ sample_group.create_dataset('claim', data=row['claim'])
50
+ sample_group.create_dataset('document', data=row['evidence'])
51
+
52
+ # Process and store images
53
+ try:
54
+ claim_img = Image.open(row['claim_image']).convert('RGB')
55
+ claim_img_tensor = preprocess(claim_img).numpy()
56
+ except Exception as e:
57
+ logger.warning(f"Error processing claim image for idx {idx}: {e}")
58
+ claim_img_tensor = np.zeros((3, 256, 256), dtype='float32')
59
+ sample_group.create_dataset('claim_image', data=claim_img_tensor)
60
+
61
+ try:
62
+ doc_img = Image.open(row['evidence_image']).convert('RGB')
63
+ doc_img_tensor = preprocess(doc_img).numpy()
64
+ except Exception as e:
65
+ logger.warning(f"Error processing evidence image for idx {idx}: {e}")
66
+ doc_img_tensor = np.zeros((3, 256, 256), dtype='float32')
67
+ sample_group.create_dataset('document_image', data=doc_img_tensor)
68
+
69
+ # Store multi-path labels
70
+ labels = category_to_labels.get(row['category'], [1, 1, 1, 1]) # Default to NEI if category not found
71
+ sample_group.create_dataset('labels', data=np.array(labels, dtype=np.int64))
72
+
73
+ logger.info(f"Created H5 dataset at {h5_path}")
74
+
75
+
76
+ class MisinformationDataset(Dataset):
77
+ def __init__(self, csv_path, pre_embed=False):
78
+ self.csv_path = csv_path
79
+ self.pre_embed = pre_embed
80
+
81
+ # Derive h5 path from csv path
82
+ base_path = os.path.splitext(csv_path)[0]
83
+ self.h5_path = base_path + '_embeddings.h5' if pre_embed else base_path + '.h5'
84
+
85
+ if not os.path.exists(self.h5_path):
86
+ if pre_embed:
87
+ raise FileNotFoundError(f"Pre-computed embeddings not found at {self.h5_path}. "
88
+ f"Please run preprocess_embeddings.py first.")
89
+ logger.info(f"H5 file not found at {self.h5_path}. Creating new H5 dataset...")
90
+ prepare_h5_dataset(self.csv_path, self.h5_path)
91
+
92
+ self.h5_file = h5py.File(self.h5_path, 'r')
93
+ self.length = len(self.h5_file.keys())
94
+
95
+ def __len__(self):
96
+ return self.length
97
+
98
+ def __getitem__(self, idx):
99
+ sample = self.h5_file[str(idx)]
100
+
101
+ if self.pre_embed:
102
+ return {
103
+ 'id': str(idx),
104
+ 'claim_text_embeds': torch.from_numpy(sample['claim_text_embeds'][()]),
105
+ 'doc_text_embeds': torch.from_numpy(sample['doc_text_embeds'][()]),
106
+ 'claim_image_embeds': torch.from_numpy(sample['claim_image_embeds'][()]),
107
+ 'doc_image_embeds': torch.from_numpy(sample['doc_image_embeds'][()]),
108
+ 'labels': torch.from_numpy(sample['labels'][()])
109
+ }
110
+ else:
111
+ return {
112
+ 'id': str(idx),
113
+ 'claim': sample['claim'][()].decode(),
114
+ 'claim_image': torch.from_numpy(sample['claim_image'][()]),
115
+ 'document': sample['document'][()].decode(),
116
+ 'document_image': torch.from_numpy(sample['document_image'][()]),
117
+ 'labels': torch.from_numpy(sample['labels'][()])
118
+ }
119
+
120
+ def __del__(self):
121
+ if hasattr(self, 'h5_file'):
122
+ self.h5_file.close()
123
+
124
+
125
+ def get_dataloader(csv_path, batch_size=32, num_workers=4, shuffle=False, pre_embed=False):
126
+ dataset = MisinformationDataset(csv_path, pre_embed=pre_embed)
127
+
128
+ dataloader = DataLoader(
129
+ dataset,
130
+ batch_size=batch_size,
131
+ shuffle=shuffle,
132
+ num_workers=num_workers,
133
+ pin_memory=True
134
+ )
135
+
136
+ return dataloader
137
+
138
+
139
+ if __name__ == "__main__":
140
+ # Set up logging
141
+ logging.basicConfig(level=logging.INFO)
142
+
143
+ # Create dataloaders
144
+ train_loader = get_dataloader('data/preprocessed/train.csv', shuffle=True)
145
+ #test_loader = get_dataloader('data/preprocessed/test.csv', shuffle=False)
146
+
147
+ # Test dataloaders
148
+ for batch in train_loader:
149
+ print("Train batch:")
150
+ print(f"Batch size: {len(batch['id'])}")
151
+ print(f"Claim shape: {batch['claim_image'].shape}")
152
+ print(f"Document image shape: {batch['document_image'].shape}")
153
+ print(f"Labels shape: {batch['labels'].shape}") # Should be (batch_size, 4)
154
+ print(f"Sample labels: {batch['labels'][0]}") # Show labels for first item
155
+ break
156
+
157
+ #for batch in test_loader:
158
+ # print("\nTest batch:")
159
+ # print(f"Batch size: {len(batch['id'])}")
160
+ # print(f"Claim shape: {batch['claim_image'].shape}")
161
+ # print(f"Document image shape: {batch['document_image'].shape}")
162
+ # print(f"Labels shape: {batch['labels'].shape}")
163
+ # print(f"Sample labels: {batch['labels'][0]}")
164
+ # break
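A small sanity check of the category_to_labels mapping defined above: each Factify category expands into four per-path labels ordered as (text-text, text-image, image-text, image-image), with 0 = Support, 1 = NEI, 2 = Refute.

from src.model.dataset import category_to_labels

paths = ["text-text", "text-image", "image-text", "image-image"]
for category in ("Support_Text", "Insufficient_Multimodal", "Refute"):
    print(category, dict(zip(paths, category_to_labels[category])))
# Support_Text            -> text-text: 0 (Support), all other paths: 1 (NEI)
# Insufficient_Multimodal -> image-image: 0 (Support), all other paths: 1 (NEI)
# Refute                  -> 2 (Refute) on every path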
src/model/layers.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class MLP(nn.Module):
6
+ """
7
+ MLP block with GELU activation and dropout.
8
+ """
9
+ def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.1):
10
+ super().__init__()
11
+ hidden_dim = int(embed_dim * mlp_ratio)
12
+ self.net = nn.Sequential(
13
+ nn.Linear(embed_dim, hidden_dim),
14
+ nn.GELU(),
15
+ nn.Dropout(dropout),
16
+ nn.Linear(hidden_dim, embed_dim),
17
+ nn.Dropout(dropout)
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.net(x)
22
+
23
+
24
+ class MultiHeadAttention(nn.Module):
25
+ """
26
+ Multi-head attention module with optional fused attention support.
27
+ """
28
+ def __init__(self, embed_dim, num_heads, dropout=0.1, fused_attn=False):
29
+ super().__init__()
30
+ self.embed_dim = embed_dim
31
+ self.num_heads = num_heads
32
+ self.dropout = dropout
33
+ self.fused_attn = fused_attn
34
+ self.attn_dropout = nn.Dropout(dropout)
35
+
36
+ def forward(self, Q, K, V, out_proj):
37
+ B, T, D = Q.shape
38
+ head_dim = D // self.num_heads
39
+
40
+ Q_ = Q.view(B, T, self.num_heads, head_dim).transpose(1, 2) # (B, num_heads, T, head_dim)
41
+ K_ = K.view(B, -1, self.num_heads, head_dim).transpose(1, 2)
42
+ V_ = V.view(B, -1, self.num_heads, head_dim).transpose(1, 2)
43
+
44
+ if self.fused_attn:
45
+ context = F.scaled_dot_product_attention(
46
+ Q_, K_, V_,
47
+ dropout_p=self.dropout if self.training else 0.0,
48
+ is_causal=False
49
+ )
50
+ else:
51
+ scores = torch.matmul(Q_, K_.transpose(-1, -2)) / (head_dim ** 0.5)
52
+ attn_weights = F.softmax(scores, dim=-1)
53
+ attn_weights = self.attn_dropout(attn_weights)
54
+ context = torch.matmul(attn_weights, V_) # (B, num_heads, T, head_dim)
55
+
56
+ context = context.transpose(1, 2).contiguous().view(B, T, D)
57
+ out = out_proj(context)
58
+ return out
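A minimal shape check for MultiHeadAttention; the output projection is supplied by the caller (as the model modules do), so a standalone Linear is created here just for the test:

import torch
import torch.nn as nn
from src.model.layers import MultiHeadAttention

B, T_q, T_kv, D, num_heads = 2, 5, 7, 512, 8
attn = MultiHeadAttention(embed_dim=D, num_heads=num_heads, dropout=0.1)
out_proj = nn.Linear(D, D)

Q = torch.randn(B, T_q, D)
K = torch.randn(B, T_kv, D)
V = torch.randn(B, T_kv, D)
out = attn(Q, K, V, out_proj)
print(out.shape)  # torch.Size([2, 5, 512]) -- the queries keep their own sequence length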
src/model/model.py ADDED
@@ -0,0 +1,432 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .layers import MLP, MultiHeadAttention
4
+
5
+
6
+ class MultiViewClaimRepresentation(nn.Module):
7
+ """
8
+ Multi-view claim representation module with transformer-like architecture
9
+ for self-attention and cross-attention in text and image modalities.
10
+ """
11
+ def __init__(self, text_input_dim=384, image_input_dim=1024, embed_dim=512, num_heads=8, dropout=0.1, mlp_ratio=4.0, fused_attn=False):
12
+ super().__init__()
13
+ self.text_input_dim = text_input_dim
14
+ self.image_input_dim = image_input_dim
15
+ self.embed_dim = embed_dim
16
+ self.num_heads = num_heads
17
+ self.dropout = dropout
18
+
19
+ self.text_proj = nn.Linear(text_input_dim, embed_dim)
20
+ self.image_proj = nn.Linear(image_input_dim, embed_dim)
21
+
22
+ # Text projections for attention
23
+ self.text_WQ = nn.Linear(embed_dim, embed_dim)
24
+ self.text_WK = nn.Linear(embed_dim, embed_dim)
25
+ self.text_WV = nn.Linear(embed_dim, embed_dim)
26
+
27
+ # Image projections for attention
28
+ self.image_WQ = nn.Linear(embed_dim, embed_dim)
29
+ self.image_WK = nn.Linear(embed_dim, embed_dim)
30
+ self.image_WV = nn.Linear(embed_dim, embed_dim)
31
+
32
+ # Output projections
33
+ self.text_self_attn_out = nn.Linear(embed_dim, embed_dim)
34
+ self.image_self_attn_out = nn.Linear(embed_dim, embed_dim)
35
+ self.text_cross_attn_out = nn.Linear(embed_dim, embed_dim)
36
+ self.image_cross_attn_out = nn.Linear(embed_dim, embed_dim)
37
+
38
+ # Layer norms
39
+ self.text_self_ln1 = nn.LayerNorm(embed_dim)
40
+ self.text_self_ln2 = nn.LayerNorm(embed_dim)
41
+ self.image_self_ln1 = nn.LayerNorm(embed_dim)
42
+ self.image_self_ln2 = nn.LayerNorm(embed_dim)
43
+ self.text_cross_ln1 = nn.LayerNorm(embed_dim)
44
+ self.text_cross_ln2 = nn.LayerNorm(embed_dim)
45
+ self.image_cross_ln1 = nn.LayerNorm(embed_dim)
46
+ self.image_cross_ln2 = nn.LayerNorm(embed_dim)
47
+
48
+ # MLPs
49
+ self.text_mlp = MLP(embed_dim, mlp_ratio, dropout)
50
+ self.image_mlp = MLP(embed_dim, mlp_ratio, dropout)
51
+
52
+ # Multi-head attention
53
+ self.attention = MultiHeadAttention(embed_dim, num_heads, dropout, fused_attn)
54
+ self.proj_dropout = nn.Dropout(dropout)
55
+
56
+ def forward(self, X_t=None, X_i=None):
57
+ """
58
+ Args:
59
+ X_t (Tensor): Text embeddings of shape (B, L_t, D)
60
+ X_i (Tensor): Image embeddings of shape (B, L_i, D)
61
+
62
+ Returns:
63
+ (H_t_fused, H_i_fused):
64
+ H_t_fused: Text representations with self- and co-attention
65
+ H_i_fused: Image representations with self- and co-attention
66
+ """
67
+ # Project inputs to embedding dimension first
68
+ if X_t is not None:
69
+ X_t = self.text_proj(X_t)
70
+ if X_i is not None:
71
+ X_i = self.image_proj(X_i)
72
+
73
+ # Pre-compute Q,K,V for both modalities if present
74
+ text_Q = self.text_WQ(X_t) if X_t is not None else None
75
+ text_K = self.text_WK(X_t) if X_t is not None else None
76
+ text_V = self.text_WV(X_t) if X_t is not None else None
77
+
78
+ image_Q = self.image_WQ(X_i) if X_i is not None else None
79
+ image_K = self.image_WK(X_i) if X_i is not None else None
80
+ image_V = self.image_WV(X_i) if X_i is not None else None
81
+
82
+ # Unimodal text case
83
+ if X_t is not None and X_i is None:
84
+ # Self attention without MLP
85
+ H_t = X_t + self.attention(text_Q, text_K, text_V, self.text_self_attn_out)
86
+ H_t = self.text_self_ln1(H_t)
87
+ # Apply MLP after self attention
88
+ H_t = H_t + self.text_mlp(H_t)
89
+ H_t = self.text_self_ln2(H_t)
90
+ return H_t, None
91
+
92
+ # Unimodal image case
93
+ if X_i is not None and X_t is None:
94
+ # Self attention without MLP
95
+ H_i = X_i + self.attention(image_Q, image_K, image_V, self.image_self_attn_out)
96
+ H_i = self.image_self_ln1(H_i)
97
+ # Apply MLP after self attention
98
+ H_i = H_i + self.image_mlp(H_i)
99
+ H_i = self.image_self_ln2(H_i)
100
+ return None, H_i
101
+
102
+ # Multimodal case
103
+ # Text processing
104
+ H_t = X_t + self.attention(text_Q, text_K, text_V, self.text_self_attn_out) # Self attention
105
+ H_t = self.text_self_ln1(H_t)
106
+ C_t = H_t + self.attention(H_t, text_K, text_V, self.text_cross_attn_out) # Cross attention
107
+ C_t = self.text_cross_ln1(C_t)
108
+ # Apply MLP after combined attention
109
+ C_t = C_t + self.text_mlp(C_t)
110
+ C_t = self.text_cross_ln2(C_t)
111
+
112
+ # Image processing
113
+ H_i = X_i + self.attention(image_Q, image_K, image_V, self.image_self_attn_out) # Self attention
114
+ H_i = self.image_self_ln1(H_i)
115
+ C_i = H_i + self.attention(H_i, image_K, image_V, self.image_cross_attn_out) # Cross attention
116
+ C_i = self.image_cross_ln1(C_i)
117
+ # Apply MLP after combined attention
118
+ C_i = C_i + self.image_mlp(C_i)
119
+ C_i = self.image_cross_ln2(C_i)
120
+
121
+ return C_t, C_i
122
+
123
+
124
+ class CrossAttentionEvidenceConditioning(nn.Module):
125
+ """
126
+ Cross-attention module to condition claim representations
127
+ on textual and visual evidence.
128
+ """
129
+ def __init__(self, text_input_dim=384, image_input_dim=1024, embed_dim=768, num_heads=8, dropout=0.1, mlp_ratio=4.0, fused_attn=False):
130
+ super().__init__()
131
+ self.num_heads = num_heads
132
+ self.embed_dim = embed_dim
133
+ self.dropout = dropout
134
+ self.fused_attn = fused_attn
135
+
136
+ # Query projections
137
+ self.text_WQ = nn.Linear(embed_dim, embed_dim)
138
+ self.image_WQ = nn.Linear(embed_dim, embed_dim)
139
+
140
+ # Text evidence projections
141
+ self.text_evidence_key = nn.Linear(text_input_dim, embed_dim)
142
+ self.text_evidence_value = nn.Linear(text_input_dim, embed_dim)
143
+
144
+ # Image evidence projections
145
+ self.image_evidence_key = nn.Linear(image_input_dim, embed_dim)
146
+ self.image_evidence_value = nn.Linear(image_input_dim, embed_dim)
147
+
148
+ # Separate output projections for each attention path
149
+ self.text_text_out = nn.Linear(embed_dim, embed_dim)
150
+ self.text_image_out = nn.Linear(embed_dim, embed_dim)
151
+ self.image_text_out = nn.Linear(embed_dim, embed_dim)
152
+ self.image_image_out = nn.Linear(embed_dim, embed_dim)
153
+
154
+ # Separate layer norms for each attention path
155
+ self.text_text_ln1 = nn.LayerNorm(embed_dim)
156
+ self.text_text_ln2 = nn.LayerNorm(embed_dim)
157
+ self.text_image_ln1 = nn.LayerNorm(embed_dim)
158
+ self.text_image_ln2 = nn.LayerNorm(embed_dim)
159
+ self.image_text_ln1 = nn.LayerNorm(embed_dim)
160
+ self.image_text_ln2 = nn.LayerNorm(embed_dim)
161
+ self.image_image_ln1 = nn.LayerNorm(embed_dim)
162
+ self.image_image_ln2 = nn.LayerNorm(embed_dim)
163
+
164
+ # MLPs
165
+ self.text_mlp = MLP(embed_dim, mlp_ratio, dropout)
166
+ self.image_mlp = MLP(embed_dim, mlp_ratio, dropout)
167
+
168
+ # Multi-head attention
169
+ self.attention = MultiHeadAttention(embed_dim, num_heads, dropout, fused_attn)
170
+ self.proj_dropout = nn.Dropout(dropout)
171
+
172
+ def forward(self, H_t=None, H_i=None, E_t=None, E_i=None):
173
+ """
174
+ Returns:
175
+ (S_t, S_i): Each contains a tuple of (text_evidence_output, image_evidence_output)
176
+ """
177
+ S_t_t, S_t_i = None, None
178
+ S_i_t, S_i_i = None, None
179
+
180
+ if H_t is not None:
181
+ # Text-to-text evidence attention
182
+ S_t_t = self.attention(
183
+ Q=self.text_WQ(H_t),
184
+ K=self.text_evidence_key(E_t),
185
+ V=self.text_evidence_value(E_t),
186
+ out_proj=self.text_text_out
187
+ )
188
+ S_t_t = H_t + S_t_t
189
+ S_t_t = self.text_text_ln1(S_t_t)
190
+ S_t_t = S_t_t + self.text_mlp(S_t_t)
191
+ S_t_t = self.text_text_ln2(S_t_t)
192
+
193
+ # Text-to-image evidence attention
194
+ S_t_i = self.attention(
195
+ Q=self.text_WQ(H_t),
196
+ K=self.image_evidence_key(E_i),
197
+ V=self.image_evidence_value(E_i),
198
+ out_proj=self.text_image_out
199
+ )
200
+ S_t_i = H_t + S_t_i
201
+ S_t_i = self.text_image_ln1(S_t_i)
202
+ S_t_i = S_t_i + self.text_mlp(S_t_i)
203
+ S_t_i = self.text_image_ln2(S_t_i)
204
+
205
+ if H_i is not None:
206
+ # Image-to-text evidence attention
207
+ S_i_t = self.attention(
208
+ Q=self.image_WQ(H_i),
209
+ K=self.text_evidence_key(E_t),
210
+ V=self.text_evidence_value(E_t),
211
+ out_proj=self.image_text_out
212
+ )
213
+ S_i_t = H_i + S_i_t
214
+ S_i_t = self.image_text_ln1(S_i_t)
215
+ S_i_t = S_i_t + self.image_mlp(S_i_t)
216
+ S_i_t = self.image_text_ln2(S_i_t)
217
+
218
+ # Image-to-image evidence attention
219
+ S_i_i = self.attention(
220
+ Q=self.image_WQ(H_i),
221
+ K=self.image_evidence_key(E_i),
222
+ V=self.image_evidence_value(E_i),
223
+ out_proj=self.image_image_out
224
+ )
225
+ S_i_i = H_i + S_i_i
226
+ S_i_i = self.image_image_ln1(S_i_i)
227
+ S_i_i = S_i_i + self.image_mlp(S_i_i)
228
+ S_i_i = self.image_image_ln2(S_i_i)
229
+
230
+ return (S_t_t, S_t_i), (S_i_t, S_i_i)
231
+
232
+
233
+ class ClassificationModule(nn.Module):
234
+ """
235
+ Classification module that takes final text/image representations
236
+ and outputs logits for {support, refute, not enough info}
237
+ for each evidence path.
238
+ """
239
+ def __init__(self, embed_dim=768, hidden_dim=256, num_classes=3, dropout=0.1):
240
+ super().__init__()
241
+ # MLPs for text representations
242
+ self.mlp_text_given_text = nn.Sequential(
243
+ nn.Linear(embed_dim, hidden_dim),
244
+ nn.ReLU(),
245
+ nn.Dropout(dropout),
246
+ nn.Linear(hidden_dim, num_classes)
247
+ )
248
+ self.mlp_text_given_image = nn.Sequential(
249
+ nn.Linear(embed_dim, hidden_dim),
250
+ nn.ReLU(),
251
+ nn.Dropout(dropout),
252
+ nn.Linear(hidden_dim, num_classes)
253
+ )
254
+
255
+ # MLPs for image representations
256
+ self.mlp_image_given_text = nn.Sequential(
257
+ nn.Linear(embed_dim, hidden_dim),
258
+ nn.ReLU(),
259
+ nn.Dropout(dropout),
260
+ nn.Linear(hidden_dim, num_classes)
261
+ )
262
+ self.mlp_image_given_image = nn.Sequential(
263
+ nn.Linear(embed_dim, hidden_dim),
264
+ nn.ReLU(),
265
+ nn.Dropout(dropout),
266
+ nn.Linear(hidden_dim, num_classes)
267
+ )
268
+
269
+ def forward(self, S_t=None, S_i=None):
270
+ """
271
+ Args:
272
+ S_t: Tuple of (text_given_text, text_given_image) representations
273
+ S_i: Tuple of (image_given_text, image_given_image) representations
274
+ Returns:
275
+ y_t: Tuple of (text_given_text_logits, text_given_image_logits)
276
+ y_i: Tuple of (image_given_text_logits, image_given_image_logits)
277
+ """
278
+ y_t_t, y_t_i = None, None
279
+ y_i_t, y_i_i = None, None
280
+
281
+ if S_t is not None:
282
+ S_t_t, S_t_i = S_t
283
+ if S_t_t is not None:
284
+ pooled_t_t = S_t_t.mean(dim=1)
285
+ y_t_t = self.mlp_text_given_text(pooled_t_t)
286
+ if S_t_i is not None:
287
+ pooled_t_i = S_t_i.mean(dim=1)
288
+ y_t_i = self.mlp_text_given_image(pooled_t_i)
289
+
290
+ if S_i is not None:
291
+ S_i_t, S_i_i = S_i
292
+ if S_i_t is not None:
293
+ pooled_i_t = S_i_t.mean(dim=1)
294
+ y_i_t = self.mlp_image_given_text(pooled_i_t)
295
+ if S_i_i is not None:
296
+ pooled_i_i = S_i_i.mean(dim=1)
297
+ y_i_i = self.mlp_image_given_image(pooled_i_i)
298
+
299
+ return (y_t_t, y_t_i), (y_i_t, y_i_i)
300
+
301
+
302
+ class MisinformationDetectionModel(nn.Module):
303
+ """
304
+ End-to-end model combining:
305
+ 1) Multi-view claim representation
306
+ 2) Cross-attention evidence conditioning
307
+ 3) Classification for each evidence path
308
+ """
309
+ def __init__(self,
310
+ text_input_dim=384, # DeBERTa-v3-xsmall hidden size
311
+ image_input_dim=1024, # Swinv2-base hidden size
312
+ embed_dim=512,
313
+ num_heads=8,
314
+ dropout=0.1,
315
+ hidden_dim=256,
316
+ num_classes=3,
317
+ mlp_ratio=4.0,
318
+ fused_attn=False):
319
+ super().__init__()
320
+
321
+ self.representation = MultiViewClaimRepresentation(
322
+ text_input_dim=text_input_dim,
323
+ image_input_dim=image_input_dim,
324
+ embed_dim=embed_dim,
325
+ num_heads=num_heads,
326
+ dropout=dropout,
327
+ mlp_ratio=mlp_ratio,
328
+ fused_attn=fused_attn
329
+ )
330
+ self.cross_attn = CrossAttentionEvidenceConditioning(
331
+ text_input_dim=text_input_dim,
332
+ image_input_dim=image_input_dim,
333
+ embed_dim=embed_dim,
334
+ num_heads=num_heads,
335
+ dropout=dropout,
336
+ mlp_ratio=mlp_ratio,
337
+ fused_attn=fused_attn
338
+ )
339
+ self.classifier = ClassificationModule(
340
+ embed_dim=embed_dim,
341
+ hidden_dim=hidden_dim,
342
+ num_classes=num_classes,
343
+ dropout=dropout
344
+ )
345
+
346
+ # Initialize weights
347
+ self._initialize_weights()
348
+
349
+ def _initialize_weights(self):
350
+ for module in self.modules():
351
+ if isinstance(module, nn.Linear):
352
+ nn.init.xavier_uniform_(module.weight)
353
+ if module.bias is not None:
354
+ nn.init.zeros_(module.bias)
355
+ elif isinstance(module, nn.LayerNorm):
356
+ nn.init.ones_(module.weight)
357
+ nn.init.zeros_(module.bias)
358
+
359
+ def forward(self, X_t=None, X_i=None, E_t=None, E_i=None):
360
+ """
361
+ Args:
362
+ X_t (Tensor): Text claim embeddings (B, L_t, D)
363
+ X_i (Tensor): Image claim embeddings (B, L_i, D)
364
+ E_t (Tensor): Text evidence embeddings (B, L_e_t, D)
365
+ E_i (Tensor): Image evidence embeddings (B, L_e_i, D)
366
+
367
+ Returns:
368
+ y_t: Tuple of (text_given_text_logits, text_given_image_logits)
369
+ y_i: Tuple of (image_given_text_logits, image_given_image_logits)
370
+ Each logit tensor has shape (B, num_classes)
371
+ """
372
+ # Get fused claim representations
373
+ H_t, H_i = self.representation(X_t, X_i)
374
+
375
+ # Get evidence-conditioned representations for each path
376
+ (S_t_t, S_t_i), (S_i_t, S_i_i) = self.cross_attn(H_t, H_i, E_t, E_i)
377
+
378
+ # Get predictions for each evidence path
379
+ (y_t_t, y_t_i), (y_i_t, y_i_i) = self.classifier(
380
+ S_t=(S_t_t, S_t_i),
381
+ S_i=(S_i_t, S_i_i)
382
+ )
383
+
384
+ return (y_t_t, y_t_i), (y_i_t, y_i_i)
385
+
386
+
387
+ if __name__ == "__main__":
388
+ # Example usage
389
+ batch_size = 2
390
+ seq_len_t = 5
391
+ seq_len_i = 7
392
+ evidence_len_t = 6
393
+ evidence_len_i = 8
394
+ embed_dim = 768
395
+
396
+ # Create random embeddings
397
+ text_claim = torch.randn(batch_size, seq_len_t, embed_dim)
398
+ image_claim = torch.randn(batch_size, seq_len_i, embed_dim)
399
+ text_evidence = torch.randn(batch_size, evidence_len_t, embed_dim)
400
+ image_evidence = torch.randn(batch_size, evidence_len_i, embed_dim)
401
+
402
+ # Build model
403
+ model = MisinformationDetectionModel(
404
+ text_input_dim=embed_dim, # the random claims/evidence above use embed_dim features, not the 384/1024 defaults
+ image_input_dim=embed_dim,
+ embed_dim=embed_dim,
405
+ num_heads=8,
406
+ dropout=0.1,
407
+ hidden_dim=256,
408
+ num_classes=3
409
+ )
410
+
411
+ # Forward pass (multimodal)
412
+ (y_t_t, y_t_i), (y_i_t, y_i_i) = model(
413
+ X_t=text_claim,
414
+ X_i=image_claim,
415
+ E_t=text_evidence,
416
+ E_i=image_evidence
417
+ )
418
+ print("Text-Text logits:", y_t_t.shape) # [B, 3]
419
+ print("Text-Image logits:", y_t_i.shape) # [B, 3]
420
+ print("Image-Text logits:", y_i_t.shape) # [B, 3]
421
+ print("Image-Image logits:", y_i_i.shape) # [B, 3]
422
+
423
+ # Forward pass (unimodal text)
424
+ (y_t_t, y_t_i), (y_i_t, y_i_i) = model(
425
+ X_t=text_claim,
426
+ E_t=text_evidence
427
+ )
428
+ print("\nUnimodal Text:")
429
+ print("Text-Text logits:", y_t_t.shape if y_t_t is not None else None)
430
+ print("Text-Image logits:", y_t_i if y_t_i is not None else None)
431
+ print("Image-Text logits:", y_i_t if y_i_t is not None else None)
432
+ print("Image-Image logits:", y_i_i if y_i_i is not None else None)
src/preprocess/__init__.py ADDED
File without changes
src/preprocess/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
src/preprocess/__pycache__/caption.cpython-311.pyc ADDED
Binary file (5.88 kB). View file
 
src/preprocess/__pycache__/preprocess.cpython-311.pyc ADDED
Binary file (3.74 kB). View file
 
src/preprocess/caption.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ from typing import Tuple
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+ from PIL import Image
6
+ from transformers import BlipProcessor, BlipForConditionalGeneration
7
+ from src.utils.path_utils import get_project_root
8
+
9
+ # Initialize BLIP model and processor
10
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
11
+ model = BlipForConditionalGeneration.from_pretrained(
12
+ "Salesforce/blip-image-captioning-large"
13
+ )
14
+
15
+ PROJECT_ROOT = get_project_root()
16
+ RAW_DIR = PROJECT_ROOT / "data/raw/factify"
17
+ PROCESSED_DIR = PROJECT_ROOT / "data/preprocessed"
18
+
19
+ BATCH_SIZE = 20 # Number of rows to process per batch
20
+
21
+
22
+ def generate_caption(image_path: str) -> str:
23
+ """Generates a caption for an image given its path."""
24
+ try:
25
+ image = Image.open(f"{PROJECT_ROOT}/{image_path}").convert("RGB")
26
+ inputs = processor(image, return_tensors="pt")
27
+ output = model.generate(**inputs)
28
+ return processor.decode(output[0], skip_special_tokens=True)
29
+ except Exception as e:
30
+ print(f"Error processing image {image_path}: {e}")
31
+ return ""
32
+
33
+
34
+ def process_image_row(row: pd.Series) -> Tuple[str, str, str, str]:
35
+ """Processes a single row to generate captions and enriched columns."""
36
+ claim_image_caption = generate_caption(row["claim_image"])
37
+ evidence_image_caption = generate_caption(row["evidence_image"])
38
+
39
+ claim_enriched = f"{row['claim']}. {claim_image_caption}"
40
+ evidence_enriched = f"{row['evidence']}. {evidence_image_caption}"
41
+
42
+ return (
43
+ claim_image_caption,
44
+ evidence_image_caption,
45
+ claim_enriched,
46
+ evidence_enriched,
47
+ )
48
+
49
+
50
+ def get_last_processed_index(df: pd.DataFrame) -> int:
51
+ """
52
+ Find the last processed row index by searching backwards from the end
53
+ until finding a row where evidence_image_caption is not NA.
54
+ Returns -1 if no processed rows are found.
55
+ """
56
+ for idx in range(len(df) - 1, -1, -1):
57
+ if pd.notna(df.loc[idx, "evidence_image_caption"]):
58
+ return idx
59
+ return -1
60
+
61
+
62
+ def process_csv(input_csv: str, output_csv: str) -> None:
63
+ """Processes the CSV in chunks and writes results incrementally with efficient checkpointing."""
64
+ # Load input DataFrame
65
+ input_df = pd.read_csv(input_csv)
66
+
67
+ # Initialize or load output DataFrame
68
+ if os.path.exists(output_csv):
69
+ output_df = pd.read_csv(output_csv)
70
+ if len(output_df) != len(input_df):
71
+ print(
72
+ "Mismatch in input and output CSV lengths. Reinitializing output CSV..."
73
+ )
+ output_df = input_df.copy()
+ for col in ["claim_image_caption", "evidence_image_caption", "claim_enriched", "evidence_enriched"]:
+ output_df[col] = pd.NA
74
+ else:
75
+ output_df = input_df.copy()
76
+ for col in [
77
+ "claim_image_caption",
78
+ "evidence_image_caption",
79
+ "claim_enriched",
80
+ "evidence_enriched",
81
+ ]:
82
+ output_df[col] = pd.NA
83
+
84
+ # Find the last processed index
85
+ last_processed_idx = get_last_processed_index(output_df)
86
+ print(f"Resuming from index {last_processed_idx + 1}")
87
+
88
+ # Process remaining rows in batches
89
+ total_rows = len(input_df)
90
+ with tqdm(total=total_rows, initial=last_processed_idx + 1) as pbar:
91
+ for idx in range(last_processed_idx + 1, total_rows, BATCH_SIZE):
92
+ batch_end = min(idx + BATCH_SIZE, total_rows)
93
+
94
+ # Process each row in the batch
95
+ for row_idx in range(idx, batch_end):
96
+ row = input_df.iloc[row_idx]
97
+
98
+ # Skip if already processed
99
+ if pd.notna(output_df.at[row_idx, "evidence_image_caption"]):
100
+ continue
101
+
102
+ # Process the row
103
+ claim_cap, evidence_cap, claim_enr, evidence_enr = process_image_row(
104
+ row
105
+ )
106
+
107
+ # Update the output DataFrame
108
+ output_df.loc[row_idx, "claim_image_caption"] = claim_cap
109
+ output_df.loc[row_idx, "evidence_image_caption"] = evidence_cap
110
+ output_df.loc[row_idx, "claim_enriched"] = claim_enr
111
+ output_df.loc[row_idx, "evidence_enriched"] = evidence_enr
112
+
113
+ pbar.update(1)
114
+
115
+ # Save after each batch
116
+ output_df.to_csv(output_csv, index=False)
117
+ print(f"Saved progress at index {batch_end}")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ for name in ["train", "test"]:
122
+ input_csv = f"{PROCESSED_DIR}/{name}.csv"
123
+ output_csv = f"{PROCESSED_DIR}/{name}_enriched.csv"
124
+
125
+ if not os.path.exists(input_csv):
126
+ raise FileNotFoundError(f"Input CSV file does not exist: {input_csv}")
127
+
128
+ process_csv(input_csv, output_csv)
129
+ print(f"Processing complete. Output saved to {output_csv}")
src/preprocess/preprocess.py ADDED
@@ -0,0 +1,82 @@
1
+ import pandas as pd
2
+
3
+ from src.utils.data_utils import HEADERS
4
+ from src.utils.path_utils import get_project_root
5
+
6
+ # Constants
7
+ PROJECT_ROOT = get_project_root()
8
+ RAW_DIR = PROJECT_ROOT / "data/raw/factify"
9
+ PROCESSED_DIR = PROJECT_ROOT / "data/preprocessed"
10
+ IMAGES_DIR = RAW_DIR / "extracted/images"
11
+
12
+
13
+ def ensure_directories():
14
+ """Ensure that necessary directories exist."""
15
+ PROCESSED_DIR.mkdir(parents=True, exist_ok=True) # Create 'data/preprocessed'
16
+
17
+
18
+ def preprocess_csv(dataset: str):
19
+ """
20
+ Preprocess the given dataset CSV (train or test).
21
+
22
+ Args:
23
+ dataset (str): The dataset name ('train' or 'test').
24
+ """
25
+ # Paths
26
+ ensure_directories()
27
+
28
+ csv_path = RAW_DIR / f"extracted/{dataset}.csv"
29
+ processed_csv_path = PROCESSED_DIR / f"{dataset}.csv"
30
+ images_folder = IMAGES_DIR / dataset
31
+
32
+ if not csv_path.exists():
33
+ print(f"Dataset CSV not found: {csv_path}")
34
+ return
35
+
36
+ # Load the CSV
37
+ df = pd.read_csv(csv_path, names=HEADERS, header=None, sep="\t", skiprows=1)
38
+
39
+ # Update file paths for images
40
+ def update_image_path(row, column_name):
41
+ """Update the image path if it exists, else leave as None."""
42
+ image_file = row[column_name]
43
+ file_id = row["id"]
44
+ if column_name == "claim_image_original":
45
+ file_path = images_folder / f"{file_id}_claim.jpg"
46
+ elif column_name == "evidence_image_original":
47
+ file_path = images_folder / f"{file_id}_evidence.jpg"
48
+ else:
49
+ return None
50
+
51
+ # Check if the file exists
52
+ if file_path.exists():
53
+ # Use the relative path starting from "/data/.."
54
+ return str(file_path.relative_to(PROJECT_ROOT))
55
+ return None
56
+
57
+ df.rename(
58
+ columns={
59
+ "claim_image": "claim_image_original",
60
+ "evidence_image": "evidence_image_original",
61
+ },
62
+ inplace=True,
63
+ )
64
+ df["claim_image"] = df.apply(
65
+ lambda row: update_image_path(row, "claim_image_original"), axis=1
66
+ )
67
+ df["evidence_image"] = df.apply(
68
+ lambda row: update_image_path(row, "evidence_image_original"), axis=1
69
+ )
70
+
71
+ # Save the processed CSV
72
+ df.to_csv(processed_csv_path, index=False)
73
+ print(f"Processed {dataset}.csv saved to {processed_csv_path}")
74
+
75
+
76
+ def main():
77
+ for dataset in ["train", "test"]:
78
+ preprocess_csv(dataset)
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()
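A hedged sketch of the order in which these preprocessing entry points appear to be meant to run; each step is also runnable directly via its own __main__ block, and the alias below is only for readability:

from src.preprocess.preprocess import main as build_preprocessed_csvs

# 1. Build data/preprocessed/{train,test}.csv with resolved image paths
build_preprocessed_csvs()
# 2. Then add BLIP captions and the enriched text columns (src/preprocess/caption.py)
# 3. Then encode the enriched evidence into {split}_embeddings.h5 (src/evidence/text2text_retrieval.py)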
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (197 Bytes). View file
 
src/utils/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (3.24 kB). View file
 
src/utils/__pycache__/path_utils.cpython-311.pyc ADDED
Binary file (538 Bytes). View file
 
src/utils/data_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import pandas as pd
3
+ from PIL import Image
4
+ from typing import Dict, Any
5
+ from src.utils.path_utils import get_project_root
6
+
7
+ # Constants
8
+ PROJECT_ROOT = get_project_root()
9
+ PREPROCESSED_DIR = PROJECT_ROOT / "data/preprocessed"
10
+
11
+ HEADERS = [
12
+ "id",
13
+ "claim",
14
+ "claim_image",
15
+ "evidence",
16
+ "evidence_image",
17
+ "category",
18
+ "claim_ocr",
19
+ "evidence_ocr",
20
+ ]
21
+
22
+
23
+ def get_preprocessed_data(dataset: str = "train") -> pd.DataFrame:
24
+ """
25
+ Load the preprocessed data for the specified dataset.
26
+
27
+ Args:
28
+ dataset (str): Either 'train' or 'test'. Defaults to 'train'.
29
+
30
+ Returns:
31
+ pd.DataFrame: A DataFrame containing the preprocessed data.
32
+ """
33
+ csv_path = PREPROCESSED_DIR / f"{dataset}.csv"
34
+
35
+ if not csv_path.exists():
36
+ raise FileNotFoundError(f"Preprocessed dataset CSV not found: {csv_path}")
37
+
38
+ return pd.read_csv(csv_path)
39
+
40
+
41
+ def load_images_for_row(row: Dict[str, Any]) -> Dict[str, Any]:
42
+ """
43
+ Load the claim and evidence images for a given row of data.
44
+
45
+ Args:
46
+ row (Dict[str, Any]): A dictionary representing a row of preprocessed data.
47
+
48
+ Returns:
49
+ Dict[str, Any]: A dictionary containing the original row with loaded images added.
50
+ """
51
+ result = row.copy() # Copy the original row to avoid modifying the input
52
+ claim_image_path = row.get("claim_image")
53
+ evidence_image_path = row.get("evidence_image")
54
+
55
+ if claim_image_path and os.path.exists(claim_image_path):
56
+ try:
57
+ result["claim_image"] = Image.open(claim_image_path).convert("RGB")
58
+ except Exception as e:
59
+ print(f"Failed to load claim image from {claim_image_path}: {e}")
60
+ result["claim_image"] = None
61
+ else:
62
+ result["claim_image"] = None
63
+
64
+ if evidence_image_path and os.path.exists(evidence_image_path):
65
+ try:
66
+ result["evidence_image"] = Image.open(evidence_image_path).convert("RGB")
67
+ except Exception as e:
68
+ print(f"Failed to load evidence image from {evidence_image_path}: {e}")
69
+ result["evidence_image"] = None
70
+ else:
71
+ result["evidence_image"] = None
72
+
73
+ return result
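A minimal sketch of the two helpers together: load a preprocessed split and materialize the images for its first row (a plain dict is enough here, since load_images_for_row only uses .get and .copy):

from src.utils.data_utils import get_preprocessed_data, load_images_for_row

train_df = get_preprocessed_data("train")
first_row = load_images_for_row(train_df.iloc[0].to_dict())
print(type(first_row["claim_image"]), type(first_row["evidence_image"]))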
src/utils/path_utils.py ADDED
@@ -0,0 +1,6 @@
1
+ from pathlib import Path
2
+
3
+
4
+ def get_project_root() -> Path:
5
+ """Get the project root directory."""
6
+ return Path(__file__).parent.parent.parent