class dataset_huggingface(torch.utils.data.Dataset):
    """
    Dataset for Community Forensics, backed by a Hugging Face `datasets` object.

    Each item is a tuple `(img, label, generator_name)` where `img` is an RGB
    PIL image decoded from the raw bytes stored in the `image_data` column,
    `label` is the integer value of the `label` column, and `generator_name`
    is the value of the `model_name` column.
    """

    # Site-specific CSV with `label` and `token` columns. It is only consulted
    # when it exists, so the class stays usable on other machines.
    _TOKEN_CSV_PATH = "/nfs/turbo/coe-ahowens/jespark/tokens.csv"
    # Value of the `label` column identifying the row that holds the token.
    _TOKEN_LABEL = "huggingface_write_token"
    # `writer_batch_size` passed to `Dataset.shuffle` in 'train' mode.
    _SHUFFLE_WRITER_BATCH_SIZE = 3000
    # Modes accepted by `get_hf_dataset`.
    _VALID_MODES = ("train", "eval")

    def __init__(
        self,
        args,
        repo_id='OwensLab/CommunityForensics',
        split='Systematic+Manual',
        mode='train',
        cache_dir='',
        dtype=torch.float32,
    ):
        """
        args: Namespace of argument parser (`args.seed` is read when shuffling)
        repo_id: Hugging Face repository id of the dataset
        split: split of the dataset to use
        mode: 'train' (dataset is shuffled) or 'eval' (original order is kept)
        cache_dir: directory to cache the dataset
        dtype: data type; stored on the instance, not used by this class itself

        Raises ValueError if `mode` is neither 'train' nor 'eval'.
        """
        # `super(dataset_huggingface).__init__()` would initialise an unbound
        # super object instead of the parent class; use the zero-argument form.
        super().__init__()
        self.args = args
        self.repo_id = repo_id
        self.split = split
        self.mode = mode
        self.cache_dir = cache_dir
        self.dtype = dtype
        self.dataset = self.get_hf_dataset()

    def __getitem__(self, index):
        """
        Returns the image, label, and generator name for the given index.
        """
        record = self.dataset[index]
        label = int(record['label'])
        generator_name = record['model_name']
        # The image is stored as encoded bytes; decode it and force 3-channel RGB.
        img = PIL.Image.open(io.BytesIO(record['image_data'])).convert("RGB")
        return img, label, generator_name

    def _resolve_hf_token(self):
        """
        Returns the Hugging Face token from the site-specific CSV, or None.

        None is returned when the CSV is absent or has no row labelled
        `_TOKEN_LABEL`; `load_dataset` then falls back to its own default
        credential lookup.
        """
        if not os.path.isfile(self._TOKEN_CSV_PATH):
            return None
        token_df = pd.read_csv(self._TOKEN_CSV_PATH)
        matches = token_df.loc[token_df['label'] == self._TOKEN_LABEL, 'token'].values
        return matches[0] if len(matches) else None

    def get_hf_dataset(self):
        """
        Returns the huggingface dataset object.

        In 'train' mode the dataset is shuffled with `args.seed`; in 'eval'
        mode it is returned in its original order.

        Raises ValueError if `self.mode` is neither 'train' nor 'eval'.
        """
        # Validate before any I/O so a bad mode fails fast with a clear message
        # instead of an UnboundLocalError after the dataset has been loaded.
        if self.mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {self.mode!r}"
            )

        hf_dataset = ds.load_dataset(
            self.repo_id,
            split=self.split,
            cache_dir=self.cache_dir,
            token=self._resolve_hf_token(),
        )
        if self.mode == 'train':
            hf_dataset = hf_dataset.shuffle(
                seed=self.args.seed,
                writer_batch_size=self._SHUFFLE_WRITER_BATCH_SIZE,
            )
        return hf_dataset

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.dataset)