File size: 825 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision import transforms
import numpy as np
from PIL import Image
import cv2

from spiga.data.loaders.transforms import TargetCrop, ToOpencv, AddModel3D


def get_transformers(data_config):
    transformer_seq = [
        Opencv2Pil(),
        TargetCrop(data_config.image_size, data_config.target_dist),
        ToOpencv(),
        NormalizeAndPermute()]
    return transforms.Compose(transformer_seq)


class NormalizeAndPermute:
    def __call__(self, sample):
        image = np.array(sample['image'], dtype=float)
        image = np.transpose(image, (2, 0, 1))
        sample['image'] = image / 255
        return sample


class Opencv2Pil:
    def __call__(self, sample):
        image = cv2.cvtColor(sample['image'], cv2.COLOR_BGR2RGB)
        sample['image'] = Image.fromarray(image)
        return sample