import os
import re

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

import albumentations as A
from albumentations.pytorch import ToTensorV2

from src.config import Config
|
|
class DeepfakeDataset(Dataset):
    """Binary deepfake-classification dataset yielding (image, label) pairs.

    Images are read with OpenCV, converted BGR -> RGB, and passed through an
    albumentations pipeline (augmentation for 'train', deterministic resize +
    normalize otherwise). Labels are floats suitable for BCE-style losses:
    0.0 = real, 1.0 = fake.
    """

    # Label keywords are matched against whole path tokens (lowercased runs of
    # alphanumerics between separators), NOT raw substrings. Substring matching
    # was a bug: "ai" is a substring of "train" and "df" of "pdf", so every
    # file under a train/ directory was silently labeled fake.
    _REAL_TOKENS = frozenset({"real"})
    _FAKE_TOKENS = frozenset({"fake", "df", "synthesis", "generated", "ai"})

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory tree to scan for images. Ignored when
                file_paths/labels are provided.
            file_paths (list): Absolute paths to images (paired with labels).
            labels (list): Labels (0.0 real / 1.0 fake) aligned with file_paths.
            phase (str): 'train' enables augmentation; any other value gets
                the deterministic eval pipeline.
            max_samples (int): Optional cap on dataset size for quick debugging.

        Raises:
            ValueError: If neither root_dir nor (file_paths, labels) is given.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            # Plain head-truncation: no shuffling, so the subset inherits any
            # ordering bias from the scan/file list. Intended for debugging only.
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()

        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def scan_directory(root_dir):
        """Walk ``root_dir`` and infer labels from keywords in each file path.

        A file is labeled 0.0 when any path token equals "real" (this takes
        precedence), else 1.0 when any fake keyword token is present. Files
        matching neither keyword set are excluded from the dataset.

        Args:
            root_dir (str): Root of the directory tree to scan.

        Returns:
            tuple[list[str], list[float]]: Parallel lists of image paths and labels.
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")

        exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff')

        for root, _dirs, files in os.walk(root_dir):
            for file in files:
                if not file.lower().endswith(exts):
                    continue
                path = os.path.join(root, file)

                # Tokenize the lowercased path so keywords match whole path
                # components ("ai_generated" -> {"ai", "generated"}), never
                # accidental substrings ("train" must NOT match "ai").
                tokens = set(re.split(r'[^a-z0-9]+', path.lower()))

                if tokens & DeepfakeDataset._REAL_TOKENS:
                    label = 0.0  # "real" wins when both keyword sets appear
                elif tokens & DeepfakeDataset._FAKE_TOKENS:
                    label = 1.0
                else:
                    continue  # unlabeled file: skip entirely

                image_paths.append(path)
                labels.append(label)

        return image_paths, labels

    def _get_transforms(self):
        """Build the albumentations pipeline for the current phase."""
        size = Config.IMAGE_SIZE
        # ImageNet normalization statistics.
        normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        if self.phase == 'train':
            return A.Compose([
                A.Resize(size, size),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                # Simulates social-media re-compression artifacts, which
                # matter for deepfake detection generalization.
                # NOTE(review): quality_lower/quality_upper were replaced by
                # quality_range in albumentations >= 1.4 — confirm the pinned
                # version still accepts these kwargs.
                A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
                normalize,
                ToTensorV2(),
            ])
        return A.Compose([
            A.Resize(size, size),
            normalize,
            ToTensorV2(),
        ])

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``; skip forward past unreadable images.

        The fallback is a bounded linear scan (at most one full pass) instead
        of the previous unbounded recursion, which looped forever when every
        image was corrupt or missing.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (transformed image, float label).

        Raises:
            RuntimeError: If no image in the dataset can be loaded.
        """
        num = len(self.image_paths)
        for offset in range(num):
            cur = (idx + offset) % num
            path = self.image_paths[cur]
            try:
                image = cv2.imread(path)
                if image is None:
                    raise ValueError("Image not found or corrupt")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                # Best-effort: a single bad file should not crash an epoch.
                continue

            if self.transform:
                image = self.transform(image=image)['image']

            return image, torch.tensor(self.labels[cur], dtype=torch.float32)

        raise RuntimeError("All images in the dataset failed to load.")
|
|