import logging
import os

import h5py
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

logger = logging.getLogger(__name__)

# Image preprocessing applied before tensors are stored in the H5 file:
# resize, centre-crop to 256x256, and normalise with the standard ImageNet
# mean/std pair expected by pretrained torchvision backbones.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Each category maps to four task-specific labels
# (0 = support, 1 = insufficient evidence, 2 = refute).
category_to_labels = {
    'Support_Text': [0, 1, 1, 1],
    'Support_Multimodal': [0, 0, 0, 0],
    'Insufficient_Text': [1, 1, 1, 1],
    'Insufficient_Multimodal': [1, 1, 1, 0],
    'Refute': [2, 2, 2, 2]
}


def prepare_h5_dataset(csv_path, h5_path):
    """
    Build an H5 dataset from a CSV file. Each top-level group holds the
    complete data for one sample: claim/document text, preprocessed image
    tensors, and the four task labels.
    """
    os.makedirs(os.path.dirname(h5_path), exist_ok=True)

    df = pd.read_csv(csv_path, index_col=0)[['claim', 'claim_image', 'evidence', 'evidence_image', 'category']]

    with h5py.File(h5_path, 'w') as f:
        for idx, (_, row) in enumerate(df.iterrows()):
            sample_group = f.create_group(str(idx))

            sample_group.create_dataset('claim', data=row['claim'])
            sample_group.create_dataset('document', data=row['evidence'])

            # Fall back to a zero tensor when an image is missing or unreadable,
            # so indices stay contiguous and batch shapes stay uniform.
            try:
                claim_img = Image.open(row['claim_image']).convert('RGB')
                claim_img_tensor = preprocess(claim_img).numpy()
            except Exception as e:
                logger.warning(f"Error processing claim image for idx {idx}: {e}")
                claim_img_tensor = np.zeros((3, 256, 256), dtype='float32')
            sample_group.create_dataset('claim_image', data=claim_img_tensor)

            try:
                doc_img = Image.open(row['evidence_image']).convert('RGB')
                doc_img_tensor = preprocess(doc_img).numpy()
            except Exception as e:
                logger.warning(f"Error processing evidence image for idx {idx}: {e}")
                doc_img_tensor = np.zeros((3, 256, 256), dtype='float32')
            sample_group.create_dataset('document_image', data=doc_img_tensor)

            # Unknown categories default to the all-"insufficient" label vector.
            labels = category_to_labels.get(row['category'], [1, 1, 1, 1])
            sample_group.create_dataset('labels', data=np.array(labels, dtype=np.int64))

    logger.info(f"Created H5 dataset at {h5_path}")


class MisinformationDataset(Dataset):
    def __init__(self, csv_path, pre_embed=False):
        self.csv_path = csv_path
        self.pre_embed = pre_embed

        base_path = os.path.splitext(csv_path)[0]
        self.h5_path = base_path + '_embeddings.h5' if pre_embed else base_path + '.h5'

        if not os.path.exists(self.h5_path):
            if pre_embed:
                raise FileNotFoundError(f"Pre-computed embeddings not found at {self.h5_path}. "
                                        f"Please run preprocess_embeddings.py first.")
            logger.info(f"H5 file not found at {self.h5_path}. Creating new H5 dataset...")
            prepare_h5_dataset(self.csv_path, self.h5_path)

        # Only read the length here; the file handle itself is opened lazily in
        # __getitem__ so that each DataLoader worker process gets its own handle
        # (h5py handles are not safe to share across forked workers).
        with h5py.File(self.h5_path, 'r') as f:
            self.length = len(f.keys())
        self.h5_file = None

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Lazy per-process open; see the note in __init__.
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')
        sample = self.h5_file[str(idx)]

        if self.pre_embed:
            return {
                'id': str(idx),
                'claim_text_embeds': torch.from_numpy(sample['claim_text_embeds'][()]),
                'doc_text_embeds': torch.from_numpy(sample['doc_text_embeds'][()]),
                'claim_image_embeds': torch.from_numpy(sample['claim_image_embeds'][()]),
                'doc_image_embeds': torch.from_numpy(sample['doc_image_embeds'][()]),
                'labels': torch.from_numpy(sample['labels'][()])
            }
        else:
            return {
                'id': str(idx),
                'claim': sample['claim'][()].decode(),
                'claim_image': torch.from_numpy(sample['claim_image'][()]),
                'document': sample['document'][()].decode(),
                'document_image': torch.from_numpy(sample['document_image'][()]),
                'labels': torch.from_numpy(sample['labels'][()])
            }

    def __del__(self):
        # The handle may never have been opened if __getitem__ was not called.
        if getattr(self, 'h5_file', None) is not None:
            self.h5_file.close()


def get_dataloader(csv_path, batch_size=32, num_workers=4, shuffle=False, pre_embed=False):
    dataset = MisinformationDataset(csv_path, pre_embed=pre_embed)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True
    )

    return dataloader
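

# Sketch (assumes the embeddings file described above exists): batches from a
# pre-embedded loader carry embedding tensors rather than raw text and images.
#
#     embed_loader = get_dataloader('data/preprocessed/train.csv', pre_embed=True)
#     batch = next(iter(embed_loader))
#     # batch keys: 'id', 'claim_text_embeds', 'doc_text_embeds',
#     #             'claim_image_embeds', 'doc_image_embeds', 'labels'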


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    train_loader = get_dataloader('data/preprocessed/train.csv', shuffle=True)

    # Quick sanity check on a single batch.
    for batch in train_loader:
        print("Train batch:")
        print(f"Batch size: {len(batch['id'])}")
        print(f"Claim image shape: {batch['claim_image'].shape}")
        print(f"Document image shape: {batch['document_image'].shape}")
        print(f"Labels shape: {batch['labels'].shape}")
        print(f"Sample labels: {batch['labels'][0]}")
        break