# misinfo/src/model/dataset.py
import os
import logging

import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

logger = logging.getLogger(__name__)
# Image preprocessing: resize, center-crop to 256x256, normalize with ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
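# Applied to a PIL image, `preprocess` yields a float tensor of shape (3, 256, 256),
# matching the zero-tensor fallback used in prepare_h5_dataset below.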
# Category mapping for multi-label classification.
# Each category maps to labels for the four evidence paths, in order:
# (text-text, text-image, image-text, image-image).
# 0: Support, 1: NEI (Not Enough Information), 2: Refute
category_to_labels = {
    'Support_Text': [0, 1, 1, 1],             # Support only for text-text
    'Support_Multimodal': [0, 0, 0, 0],       # Support for all paths
    'Insufficient_Text': [1, 1, 1, 1],        # NEI for all paths
    'Insufficient_Multimodal': [1, 1, 1, 0],  # Support for image-image, NEI for all other paths
    'Refute': [2, 2, 2, 2],                   # Refute for all paths
}
def prepare_h5_dataset(csv_path, h5_path):
    """
    Prepare an H5 dataset from a CSV file, where each index holds one complete sample.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(os.path.dirname(h5_path), exist_ok=True)

    # Read the CSV file, keeping only the columns we store
    df = pd.read_csv(csv_path, index_col=0)[['claim', 'claim_image', 'evidence', 'evidence_image', 'category']]

    with h5py.File(h5_path, 'w') as f:
        # Process each row
        for idx, (_, row) in enumerate(df.iterrows()):
            # Create a group for this sample
            sample_group = f.create_group(str(idx))

            # Store text data
            sample_group.create_dataset('claim', data=row['claim'])
            sample_group.create_dataset('document', data=row['evidence'])

            # Process and store the claim image; fall back to a zero tensor on failure
            try:
                claim_img = Image.open(row['claim_image']).convert('RGB')
                claim_img_tensor = preprocess(claim_img).numpy()
            except Exception as e:
                logger.warning(f"Error processing claim image for idx {idx}: {e}")
                claim_img_tensor = np.zeros((3, 256, 256), dtype='float32')
            sample_group.create_dataset('claim_image', data=claim_img_tensor)

            # Process and store the evidence image, with the same zero-tensor fallback
            try:
                doc_img = Image.open(row['evidence_image']).convert('RGB')
                doc_img_tensor = preprocess(doc_img).numpy()
            except Exception as e:
                logger.warning(f"Error processing evidence image for idx {idx}: {e}")
                doc_img_tensor = np.zeros((3, 256, 256), dtype='float32')
            sample_group.create_dataset('document_image', data=doc_img_tensor)

            # Store multi-path labels; default to NEI on all paths if the category is unknown
            labels = category_to_labels.get(row['category'], [1, 1, 1, 1])
            sample_group.create_dataset('labels', data=np.array(labels, dtype=np.int64))

    logger.info(f"Created H5 dataset at {h5_path}")
class MisinformationDataset(Dataset):
    def __init__(self, csv_path, pre_embed=False):
        self.csv_path = csv_path
        self.pre_embed = pre_embed

        # Derive the H5 path from the CSV path
        base_path = os.path.splitext(csv_path)[0]
        self.h5_path = base_path + '_embeddings.h5' if pre_embed else base_path + '.h5'

        if not os.path.exists(self.h5_path):
            if pre_embed:
                raise FileNotFoundError(f"Pre-computed embeddings not found at {self.h5_path}. "
                                        f"Please run preprocess_embeddings.py first.")
            logger.info(f"H5 file not found at {self.h5_path}. Creating new H5 dataset...")
            prepare_h5_dataset(self.csv_path, self.h5_path)

        # Read the length eagerly, but open the file lazily in __getitem__:
        # h5py handles are not safe to share across DataLoader worker processes,
        # so each worker opens its own handle on first access.
        with h5py.File(self.h5_path, 'r') as f:
            self.length = len(f.keys())
        self.h5_file = None

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')
        sample = self.h5_file[str(idx)]
        if self.pre_embed:
            return {
                'id': str(idx),
                'claim_text_embeds': torch.from_numpy(sample['claim_text_embeds'][()]),
                'doc_text_embeds': torch.from_numpy(sample['doc_text_embeds'][()]),
                'claim_image_embeds': torch.from_numpy(sample['claim_image_embeds'][()]),
                'doc_image_embeds': torch.from_numpy(sample['doc_image_embeds'][()]),
                'labels': torch.from_numpy(sample['labels'][()])
            }
        else:
            return {
                'id': str(idx),
                'claim': sample['claim'][()].decode(),
                'claim_image': torch.from_numpy(sample['claim_image'][()]),
                'document': sample['document'][()].decode(),
                'document_image': torch.from_numpy(sample['document_image'][()]),
                'labels': torch.from_numpy(sample['labels'][()])
            }

    def __del__(self):
        if getattr(self, 'h5_file', None) is not None:
            self.h5_file.close()
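# The pre_embed branch above assumes an H5 layout produced elsewhere
# (preprocess_embeddings.py, not shown here), with one group per sample index:
#     f['0/claim_text_embeds'], f['0/doc_text_embeds'],
#     f['0/claim_image_embeds'], f['0/doc_image_embeds'], f['0/labels']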
def get_dataloader(csv_path, batch_size=32, num_workers=4, shuffle=False, pre_embed=False):
    dataset = MisinformationDataset(csv_path, pre_embed=pre_embed)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return dataloader
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Create dataloaders
    train_loader = get_dataloader('data/preprocessed/train.csv', shuffle=True)
    # test_loader = get_dataloader('data/preprocessed/test.csv', shuffle=False)

    # Smoke-test the train dataloader on a single batch
    for batch in train_loader:
        print("Train batch:")
        print(f"Batch size: {len(batch['id'])}")
        print(f"Claim image shape: {batch['claim_image'].shape}")
        print(f"Document image shape: {batch['document_image'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")  # Should be (batch_size, 4)
        print(f"Sample labels: {batch['labels'][0]}")  # Labels for the first item
        break

    # for batch in test_loader:
    #     print("\nTest batch:")
    #     print(f"Batch size: {len(batch['id'])}")
    #     print(f"Claim image shape: {batch['claim_image'].shape}")
    #     print(f"Document image shape: {batch['document_image'].shape}")
    #     print(f"Labels shape: {batch['labels'].shape}")
    #     print(f"Sample labels: {batch['labels'][0]}")
    #     break