File size: 995 Bytes
be2ced2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from PIL import Image
import torch
import torch.nn as nn
from typing import Dict, Iterable, Callable
from torch import Tensor
import glob
from tqdm import tqdm
import numpy as np
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None


# +
class RobustModel(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards only the input tensor.

    Any extra positional or keyword arguments given to ``forward`` are
    accepted and silently discarded, so callers that pass additional
    options can use a model whose ``forward`` takes just ``x``.

    Args:
        model: The callable (typically an ``nn.Module``) to wrap.
    """

    def __init__(self, model):
        super(RobustModel, self).__init__()
        # When ``model`` is an nn.Module, attribute assignment on an
        # nn.Module registers it as a child submodule.
        self.model = model

    def forward(self, x, *args, **kwargs):
        """Return ``model(x)``; ``args``/``kwargs`` are intentionally ignored."""
        inner = self.model
        output = inner(x)
        return output
    
    
class CustomArt(torch.utils.data.Dataset):
    """Dataset that serves (optionally transformed) copies of a single image.

    Every index returns the same source image, passed through ``transforms``
    when one is configured, converted to a float tensor.

    Args:
        image: The source image. Its type is not constrained here; presumably
            a PIL image or an array/tensor -- TODO confirm against callers.
        transforms: Optional callable applied to ``image`` on every access.
    """

    def __init__(self, image, transforms=None):
        self.transforms = transforms
        self.image = image
        # Per-channel mean/std. These values match the commonly used ImageNet
        # normalization statistics. They are not used inside this class;
        # presumably callers read them -- verify before removing.
        self.mean = torch.tensor([0.4850, 0.4560, 0.4060])
        self.std = torch.tensor([0.2290, 0.2240, 0.2250])

    def __getitem__(self, idx):
        """Return the image (transformed if configured) as a float tensor.

        ``idx`` is ignored: every index yields the same source image.
        """
        # The transform is optional; without one, use the raw image instead
        # of leaving ``img`` unbound (which raised UnboundLocalError).
        img = self.transforms(self.image) if self.transforms else self.image
        return torch.as_tensor(img, dtype=torch.float)

    def __len__(self):
        """Return ``len(image)``, or 1 when the image has no length."""
        try:
            return len(self.image)
        except TypeError:
            # Objects without __len__ (e.g. PIL images) would otherwise make
            # the dataset unusable with a DataLoader; a single image is one item.
            return 1