"""Single-image dataset and a forgiving model wrapper for PyTorch pipelines."""
import glob
from typing import Callable, Dict, Iterable

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from PIL import ImageFile
from torch import Tensor
from tqdm import tqdm

# Let PIL decode truncated / partially-written image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's decompression-bomb pixel limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None


# +
class RobustModel(nn.Module):
    """Wrap ``model`` so that ``forward`` tolerates extra arguments.

    Only the first positional argument ``x`` is passed to the wrapped model;
    any additional positional or keyword arguments are silently discarded.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, *args, **kwargs):
        # Extra args/kwargs are intentionally dropped (see class docstring).
        return self.model(x)


class CustomArt(torch.utils.data.Dataset):
    """Dataset that serves a single image, optionally transformed, as a float tensor.

    Every index returns the same item: ``transforms(image)`` when a transform
    was supplied, otherwise ``image`` itself, converted to a float tensor.

    Args:
        image: The source image. Presumably a PIL image or array/tensor-like
            object -- TODO confirm against callers.
        transforms: Optional callable applied to ``image`` on every access.
    """

    def __init__(self, image, transforms=None):
        self.transforms = transforms
        self.image = image
        # These values match the common ImageNet normalization constants.
        # They are stored for callers; nothing in this class applies them.
        self.mean = torch.tensor([0.4850, 0.4560, 0.4060])
        self.std = torch.tensor([0.2290, 0.2240, 0.2250])

    def __getitem__(self, idx):
        """Return the (transformed) image as a float tensor; ``idx`` is ignored."""
        # Fall back to the raw image when no transform was given; previously
        # ``img`` was left unbound in that case and raised UnboundLocalError.
        if self.transforms is not None:
            img = self.transforms(self.image)
        else:
            img = self.image
        if isinstance(img, Image.Image):
            # Convert PIL images explicitly via a writable NumPy copy rather
            # than relying on torch's handling of PIL objects.
            img = np.array(img)
        return torch.as_tensor(img, dtype=torch.float)

    def __len__(self):
        """Return ``len(image)``, or 1 when the image has no length (e.g. PIL)."""
        try:
            return len(self.image)
        except TypeError:
            # PIL images define no __len__; a single image is one sample.
            return 1