saptak21 commited on
Commit
87dd991
·
verified ·
1 Parent(s): 5afaac3

Upload 7 files

Browse files
datasets/__init__.py ADDED
File without changes
datasets/eyediap.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import h5py
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ from typing import List
7
+ from omegaconf import OmegaConf, listconfig
8
+ from .helper.image_transform import wrap_transforms
9
+
10
+
11
+ class EYEDIAPDataset(Dataset):
12
+ def __init__(self,
13
+ dataset_path: str,
14
+ color_type,
15
+ keys_to_use: List[str] = None,
16
+ data_name=None,
17
+ image_size:int=224, ## <---
18
+ transform_type='basic_imagenet', ## <--- modified
19
+ image_key='face_patch',
20
+ gaze_key='face_gaze',
21
+ ):
22
+
23
+ self.path = dataset_path
24
+ self.hdfs = {}
25
+ self.data_name = data_name
26
+ self.image_key = image_key
27
+ self.gaze_key = gaze_key
28
+
29
+ self.image_size = (image_size, image_size)
30
+
31
+ assert color_type in ['rgb', 'bgr']
32
+ self.color_type = color_type
33
+ self.selected_keys = [k for k in keys_to_use]
34
+ assert len(self.selected_keys) > 0
35
+
36
+ self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys]
37
+
38
+ for num_i in range(0, len(self.selected_keys)):
39
+ file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
40
+ self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
41
+ print('read file: ', os.path.join(self.path, self.selected_keys[num_i]))
42
+ assert self.hdfs[num_i].swmr_mode
43
+
44
+ self.build_idx_to_kv()
45
+
46
+ for num_i in range(0, len(self.hdfs)):
47
+ if self.hdfs[num_i]:
48
+ self.hdfs[num_i].close()
49
+ self.hdfs[num_i] = None
50
+ self.transform = wrap_transforms(transform_type, image_size=image_size)
51
+ self.__hdfs = None
52
+ self.hdf = None
53
+
54
+ def __len__(self):
55
+ return len(self.idx_to_kv)
56
+
57
+ def __del__(self):
58
+ for num_i in range(0, len(self.hdfs)):
59
+ if self.hdfs[num_i]:
60
+ self.hdfs[num_i].close()
61
+ self.hdfs[num_i] = None
62
+
63
+ def build_idx_to_kv(self):
64
+ self.idx_to_kv = []
65
+ self.key_idx_dict = {}
66
+ for num_i in range(0, len(self.selected_keys)):
67
+ this_sub = self.selected_keys[num_i].split('.')[0]
68
+ n = self.hdfs[num_i][self.image_key].shape[0]
69
+ self.idx_to_kv += [(num_i, i) for i in range(n)]
70
+ self.key_idx_dict[this_sub] = [ i for i in range(n)]
71
+
72
+ @property
73
+ def archives(self):
74
+ if self.__hdfs is None: # lazy loading here!
75
+ self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
76
+ return self.__hdfs
77
+
78
+
79
+ def preprocess_image(self, image):
80
+ image = image.astype(np.float32)
81
+ if self.color_type == 'bgr':
82
+ image = image[..., ::-1]
83
+ image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
84
+ image = self.transform(image.astype(np.uint8) )
85
+ return image
86
+
87
+ def __getitem__(self, index):
88
+ key, idx = self.idx_to_kv[index]
89
+ self.hdf = self.archives[key]
90
+ assert self.hdf.swmr_mode
91
+
92
+ image = self.hdf[self.image_key][idx, :]
93
+ gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float')
94
+ head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float')
95
+
96
+ entry = {
97
+ 'image': self.preprocess_image(image),
98
+ 'gaze': gaze_label,
99
+ 'head': head_label,
100
+ 'key': key,
101
+ 'index':index
102
+ }
103
+ return entry
datasets/gaze360.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import h5py, cv2
4
+ from torch.utils.data import Dataset
5
+ from typing import List
6
+ from .helper.image_transform import wrap_transforms
7
+
8
+
9
+ class Gaze360Dataset(Dataset):
10
+ def __init__(self,
11
+ dataset_path: str,
12
+ color_type,
13
+ keys_to_use: List[str] = None,
14
+ data_name=None,
15
+ image_size:int=224,
16
+ transform_type='basic_imagenet',
17
+ image_key='face_patch',
18
+ gaze_key='face_gaze',
19
+ sample_rate_use=1,
20
+ ):
21
+ super().__init__()
22
+ self.dataset_path = dataset_path
23
+ self.hdfs = {}
24
+ self.data_name = data_name
25
+ self.image_key = image_key
26
+ self.gaze_key = gaze_key
27
+ self.image_size = (image_size, image_size)
28
+
29
+ assert color_type in ['rgb', 'bgr']
30
+ self.color_type = color_type
31
+ self.transform = wrap_transforms(transform_type, image_size=image_size)
32
+
33
+ self.sample_rate_use = sample_rate_use
34
+ #### -------------------------------------------------------- read the h5 files -------------------------------------------------------
35
+ self.selected_keys = [k for k in keys_to_use]
36
+ assert len(self.selected_keys) > 0
37
+ self.file_paths = [os.path.join(self.dataset_path, k) for k in self.selected_keys]
38
+ for num_i in range(0, len(self.selected_keys)):
39
+ file_path = os.path.join(self.dataset_path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
40
+ self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
41
+ print('read file: ', os.path.join(self.dataset_path, self.selected_keys[num_i]))
42
+ assert self.hdfs[num_i].swmr_mode
43
+ ####-----------------------------------------------------------------------------------------------------------------------------------
44
+
45
+ self.build_idx_to_kv()
46
+ for num_i in range(0, len(self.hdfs)):
47
+ if self.hdfs[num_i]:
48
+ self.hdfs[num_i].close()
49
+ self.hdfs[num_i] = None
50
+
51
+ self.__hdfs = None
52
+ self.hdf = None
53
+
54
+ def build_idx_to_kv(self):
55
+ self.idx_to_kv = []
56
+ self.key_idx_dict = {}
57
+ for num_i in range(0, len(self.selected_keys)):
58
+ p_key = self.selected_keys[num_i].split('.')[0] ##p00
59
+ n = self.hdfs[num_i][self.image_key].shape[0]
60
+ if self.sample_rate_use > 1:
61
+ indices = np.arange(0, n, self.sample_rate_use)
62
+ else:
63
+ indices = np.arange(0, n)
64
+ self.idx_to_kv += [(num_i, i) for i in indices]
65
+ self.key_idx_dict[p_key] = [i for i in indices]
66
+
67
+
68
+ def __len__(self):
69
+ return len(self.idx_to_kv)
70
+
71
+ def __del__(self):
72
+ for num_i in range(0, len(self.hdfs)):
73
+ if self.hdfs[num_i]:
74
+ self.hdfs[num_i].close()
75
+ self.hdfs[num_i] = None
76
+
77
+ @property
78
+ def archives(self):
79
+ if self.__hdfs is None: # lazy loading here!
80
+ self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
81
+ return self.__hdfs
82
+
83
+ def preprocess_image(self, image):
84
+ image = image.astype(np.float32)
85
+ if self.color_type == 'bgr':
86
+ image = image[..., ::-1]
87
+ if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]:
88
+ image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
89
+ image = self.transform(image.astype(np.uint8) )
90
+ return image
91
+
92
+ def __getitem__(self, index):
93
+ key, idx = self.idx_to_kv[index]
94
+ self.hdf = self.archives[key]
95
+ image = self.hdf[self.image_key][idx]
96
+ gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float')
97
+ head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float')
98
+ entry = {
99
+ 'image': self.preprocess_image(image),
100
+ 'gaze': gaze_label,
101
+ 'head': head_label,
102
+ 'key': idx,
103
+ 'index':index
104
+ }
105
+ return entry
106
+
datasets/gazecapture.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import h5py
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ from typing import List
7
+ from omegaconf import OmegaConf, listconfig
8
+ from .helper.image_transform import wrap_transforms
9
+
10
+
11
+ class GazeCaptureDataset(Dataset):
12
+ def __init__(self,
13
+ dataset_path: str,
14
+ color_type,
15
+ keys_to_use: List[str] = None,
16
+ data_name=None,
17
+ image_size:int=224, ## <---
18
+ transform_type='basic_imagenet', ## <--- modified
19
+ image_key='face_patch',
20
+ gaze_key='face_gaze',
21
+ sample_rate_use=1,
22
+ ):
23
+
24
+ self.transform = wrap_transforms(transform_type, image_size=image_size)
25
+
26
+ self.path = dataset_path
27
+ self.hdfs = {}
28
+ self.data_name = data_name
29
+ self.image_key = image_key
30
+ self.gaze_key = gaze_key
31
+
32
+ self.image_size = (image_size, image_size)
33
+
34
+ self.sample_rate_use = sample_rate_use
35
+
36
+ assert color_type in ['rgb', 'bgr']
37
+ self.color_type = color_type
38
+ self.selected_keys = [ k for k in keys_to_use]
39
+ assert len(self.selected_keys) > 0
40
+
41
+ self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys]
42
+ for num_i in range(0, len(self.selected_keys)):
43
+ file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
44
+ self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
45
+ print('read file: ', os.path.join(self.path, self.selected_keys[num_i]))
46
+ assert self.hdfs[num_i].swmr_mode
47
+
48
+
49
+ self.build_idx_to_kv()
50
+
51
+
52
+ for num_i in range(0, len(self.hdfs)):
53
+ if self.hdfs[num_i]:
54
+ self.hdfs[num_i].close()
55
+ self.hdfs[num_i] = None
56
+
57
+ self.__hdfs = None
58
+ self.hdf = None
59
+
60
+ def __len__(self):
61
+ return len(self.idx_to_kv)
62
+
63
+ def __del__(self):
64
+ for num_i in range(0, len(self.hdfs)):
65
+ if self.hdfs[num_i]:
66
+ self.hdfs[num_i].close()
67
+ self.hdfs[num_i] = None
68
+
69
+ def build_idx_to_kv(self):
70
+ self.idx_to_kv = []
71
+ self.key_idx_dict = {}
72
+ for num_i in range(0, len(self.selected_keys)):
73
+ this_sub = self.selected_keys[num_i].split('.')[0]
74
+ n = self.hdfs[num_i][self.image_key].shape[0]
75
+ if self.sample_rate_use > 1:
76
+ indices = np.arange(0, n, self.sample_rate_use)
77
+ else:
78
+ indices = np.arange(0, n)
79
+ self.idx_to_kv += [(num_i, i) for i in indices ]
80
+ self.key_idx_dict[this_sub] = [ i for i in indices ]
81
+
82
+ @property
83
+ def archives(self):
84
+ if self.__hdfs is None: # lazy loading here!
85
+ self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
86
+ return self.__hdfs
87
+
88
+
89
+ def preprocess_image(self, image):
90
+ image = image.astype(np.float32)
91
+ if self.color_type == 'bgr':
92
+ image = image[..., ::-1]
93
+ image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
94
+ image = self.transform(image.astype(np.uint8) )
95
+ return image
96
+
97
+ def __getitem__(self, index):
98
+
99
+ key, idx = self.idx_to_kv[index]
100
+ self.hdf = self.archives[key]
101
+
102
+ # self.hdf = h5py.File(os.path.join(self.path, self.selected_keys[key]), 'r', swmr=True)
103
+ assert self.hdf.swmr_mode
104
+
105
+ image = self.hdf[self.image_key][idx, :]
106
+ gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float')
107
+ head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float')
108
+
109
+ entry = {
110
+ 'image': self.preprocess_image(image),
111
+ 'gaze': gaze_label,
112
+ 'head': head_label,
113
+ 'key': key,
114
+ 'index':index
115
+ }
116
+ return entry
117
+
118
+ # class GazeCaptureDatasetSubset(GazeCaptureDataset):
119
+ # def __init__(self, images_per_person=None, **kwargs):
120
+ # self.images_per_person = images_per_person
121
+ # super().__init__(**kwargs)
122
+
123
+ # def build_idx_to_kv(self):
124
+ # self.idx_to_kv = []
125
+ # self.key_idx_dict = {}
126
+ # for num_i in range(0, len(self.selected_keys)):
127
+ # this_sub = self.selected_keys[num_i].split('.')[0]
128
+ # n = self.hdfs[num_i][self.image_key].shape[0]
129
+ # if self.images_per_person is not None:
130
+ # n = min(n, self.images_per_person)
131
+ # self.idx_to_kv += [(num_i, i) for i in range(n)]
132
+ # self.key_idx_dict[this_sub] = [ i for i in range(n)]
datasets/helper/image_transform.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ from torchvision import transforms
4
+ import numpy as np
5
+ import torch
6
+
7
+ def re_normalize(image_tensor, old='[-1,1]', new='imagenet'):
8
+ """
9
+ Re-normalizes an image tensor from one normalization scheme to another.
10
+ Args:
11
+ image_tensor (torch.Tensor): Image tensor to be re-normalized.
12
+ old (str): Old normalization scheme. Options: '[-1,1]', 'imagenet'.
13
+ new (str): New normalization scheme. Options: '[-1,1]', 'imagenet'.
14
+ Returns:
15
+ torch.Tensor: Re-normalized image tensor.
16
+ """
17
+ # Old normalization parameters
18
+ device = image_tensor.device
19
+ if old == '[-1,1]':
20
+ old_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
21
+ old_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
22
+ elif old == 'imagenet':
23
+ old_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
24
+ old_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
25
+ elif old == '[0,1]':
26
+ old_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device)
27
+ old_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device)
28
+ else:
29
+ print('old normalization not implemented')
30
+ raise NotImplementedError
31
+ # New normalization parameters
32
+ if new == '[-1,1]':
33
+ new_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
34
+ new_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
35
+ elif new == 'imagenet':
36
+ new_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
37
+ new_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
38
+ elif new == '[0,1]':
39
+ new_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device)
40
+ new_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device)
41
+ else:
42
+ print('new normalization not implemented')
43
+ raise NotImplementedError
44
+ # Step 1: Denormalize the image tensor using the old mean and std
45
+ denormalized_image = image_tensor * old_std + old_mean
46
+ # Step 2: Normalize the image tensor using the new mean and std
47
+ normalized_image = (denormalized_image - new_mean) / new_std
48
+
49
+ return normalized_image
50
+
51
+
52
+
53
+
54
+
55
+
56
+ def wrap_transforms(image_transforms_type, image_size):
57
+
58
+
59
+ if image_transforms_type == 'basic_imagenet':
60
+ MEAN = [0.485, 0.456, 0.406]
61
+ STD = [0.229, 0.224, 0.225]
62
+ return transforms.Compose([
63
+ transforms.ToPILImage(),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=MEAN, std=STD)
66
+ ])
67
+
68
+
69
+ else:
70
+ raise NotImplementedError
71
+
72
+
73
+
74
+ # def enhance_contrast_clahe(image):
75
+ # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
76
+ # lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
77
+ # lab_planes = list( cv2.split(lab) )
78
+ # lab_planes[0] = clahe.apply(lab_planes[0])
79
+ # lab = cv2.merge(lab_planes)
80
+ # image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
81
+ # return image
datasets/mpiigaze.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import h5py
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ from typing import List
7
+ from omegaconf import OmegaConf, listconfig
8
+ from .helper.image_transform import wrap_transforms
9
+
10
+
11
+ class MPIIGazeDataset(Dataset):
12
+ def __init__(self,
13
+ dataset_path: str,
14
+ color_type,
15
+ keys_to_use: List[str] = None,
16
+ data_name=None,
17
+ image_size:int=224, ## <---
18
+ transform_type='basic_imagenet', ## <--- modified
19
+ image_key='face_patch',
20
+ gaze_key='face_gaze',
21
+ ):
22
+
23
+ self.dataset_path = dataset_path
24
+ self.hdfs = {}
25
+ self.data_name = data_name
26
+ self.image_key = image_key
27
+ self.gaze_key = gaze_key
28
+ self.image_size = (image_size, image_size)
29
+
30
+ assert color_type in ['rgb', 'bgr']
31
+ self.color_type = color_type
32
+ self.transform = wrap_transforms(transform_type, image_size=image_size)
33
+
34
+
35
+ self.selected_keys = [k for k in keys_to_use]
36
+ assert len(self.selected_keys) > 0
37
+
38
+ self.file_paths = [os.path.join(self.dataset_path, k) for k in self.selected_keys]
39
+
40
+ for num_i in range(0, len(self.selected_keys)):
41
+ file_path = os.path.join(self.dataset_path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
42
+ self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
43
+ print('read file: ', os.path.join(self.dataset_path, self.selected_keys[num_i]))
44
+ assert self.hdfs[num_i].swmr_mode
45
+
46
+ self.build_idx_to_kv()
47
+
48
+ for num_i in range(0, len(self.hdfs)):
49
+ if self.hdfs[num_i]:
50
+ self.hdfs[num_i].close()
51
+ self.hdfs[num_i] = None
52
+
53
+
54
+
55
+ self.__hdfs = None
56
+ self.hdf = None
57
+
58
+ def __len__(self):
59
+ return len(self.idx_to_kv)
60
+
61
+ def __del__(self):
62
+ for num_i in range(0, len(self.hdfs)):
63
+ if self.hdfs[num_i]:
64
+ self.hdfs[num_i].close()
65
+ self.hdfs[num_i] = None
66
+
67
+ def build_idx_to_kv(self):
68
+
69
+ self.idx_to_kv = []
70
+ self.key_idx_dict = {}
71
+ for num_i in range(0, len(self.selected_keys)):
72
+ p_key = self.selected_keys[num_i].split('.')[0] ##p00
73
+ n = self.hdfs[num_i][self.image_key].shape[0]
74
+ self.idx_to_kv += [(num_i, i) for i in range(n)]
75
+ self.key_idx_dict[p_key] = [i for i in range(n)]
76
+ @property
77
+ def archives(self):
78
+ if self.__hdfs is None: # lazy loading here!
79
+ self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
80
+ return self.__hdfs
81
+
82
+
83
+ def preprocess_image(self, image):
84
+ image = image.astype(np.float32)
85
+ if self.color_type == 'bgr':
86
+ image = image[..., ::-1]
87
+ if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]:
88
+ image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
89
+ image = self.transform(image.astype(np.uint8) )
90
+ return image
91
+
92
+ def __getitem__(self, index):
93
+ key, idx = self.idx_to_kv[index]
94
+ self.hdf = self.archives[key]
95
+ # self.hdf = h5py.File(os.path.join(self.dataset_path, self.selected_keys[key]), 'r', swmr=True)
96
+ assert self.hdf.swmr_mode
97
+ image = self.hdf[self.image_key][idx, :]
98
+ gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float')
99
+ head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float')
100
+ entry = {
101
+ 'image': self.preprocess_image(image),
102
+ 'gaze': gaze_label,
103
+ 'head': head_label,
104
+ 'key': key,
105
+ 'index':index
106
+ }
107
+
108
+ return entry
109
+
datasets/xgaze.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,random
2
+ import numpy as np
3
+ import h5py
4
+ import cv2
5
+ from typing import List
6
+ from torch.utils.data import Dataset
7
+ from .helper.image_transform import wrap_transforms
8
+
9
+ class XGazeDataset(Dataset):
10
+ def __init__(self,
11
+ dataset_path: str,
12
+ color_type,
13
+ images_per_frame,
14
+ keys_to_use: List[str] = None,
15
+ data_name=None,
16
+ image_size:int=224,
17
+ transform_type='basic_imagenet', ## <--- modified
18
+ image_key='face_patch',
19
+ gaze_key='face_gaze',
20
+ camera_random=None,
21
+ frame_tag=[0,1000],
22
+ seed=0,
23
+ ):
24
+
25
+ self.path = dataset_path
26
+ self.hdfs = {}
27
+ self.data_name = data_name
28
+ self.images_per_frame = images_per_frame
29
+
30
+ print('images_per_frame: ', images_per_frame)
31
+ self.image_key = image_key
32
+ self.gaze_key = gaze_key
33
+ self.image_size = (image_size, image_size)
34
+ random.seed(seed)
35
+
36
+ assert color_type in ['rgb', 'bgr']
37
+ self.color_type = color_type
38
+ self.cameras_idx = list(range(self.images_per_frame))
39
+ self.camera_random = camera_random
40
+
41
+ #### -------------------------------------------------------- read the h5 files -------------------------------------------------------
42
+ self.selected_keys = [k for k in keys_to_use]
43
+ assert len(self.selected_keys) > 0
44
+ self.file_paths = [os.path.join(self.path, k) for k in self.selected_keys]
45
+ for num_i in range(0, len(self.selected_keys)):
46
+ file_path = os.path.join(self.path, self.selected_keys[num_i]) # the subdirectories: train, test are not used in MPIIFaceGaze and MPII_Rotate
47
+ self.hdfs[num_i] = h5py.File(file_path, 'r', swmr=True)
48
+ print('read file: ', os.path.join(self.path, self.selected_keys[num_i]))
49
+ assert self.hdfs[num_i].swmr_mode
50
+ ####-----------------------------------------------------------------------------------------------------------------------------------
51
+
52
+
53
+ self.idx_to_kv = []
54
+ self.key_idx_dict = {} ## this is for reading the second sample from the same person
55
+ for num_i in range(0, len(self.selected_keys)):
56
+ this_sub = self.selected_keys[num_i].split('.')[0]
57
+ n = self.hdfs[num_i][image_key].shape[0]
58
+
59
+ if type(frame_tag) == list:
60
+ self.start_frame, self.end_frame = frame_tag
61
+ elif frame_tag == 'all':
62
+ self.start_frame, self.end_frame = 0, 10000
63
+ else:
64
+ raise ValueError("frame_tag should be either a list of integers or str 'all' ")
65
+ start_idx = min(n, self.start_frame * self.images_per_frame)
66
+ end_idx = min(n, self.end_frame * self.images_per_frame)
67
+
68
+ if self.camera_random is None:
69
+ self.idx_to_kv += [(num_i, i) for i in range(start_idx, end_idx) if (i % self.images_per_frame ) in self.cameras_idx ]
70
+ self.key_idx_dict[this_sub] = [ i for i in range(start_idx, end_idx) if (i % self.images_per_frame ) in self.cameras_idx ]
71
+ else:
72
+ for frame in range(start_idx // self.images_per_frame, end_idx // self.images_per_frame):
73
+ frame_start_idx = frame * self.images_per_frame
74
+ frame_end_idx = frame_start_idx + self.images_per_frame
75
+
76
+ # Randomly select self.images_per_frame camera indices for this frame
77
+ random_cameras_idx = random.sample(range(self.images_per_frame), self.camera_random)
78
+ self.idx_to_kv += [(num_i, i) for i in range(frame_start_idx, frame_end_idx) if (i % self.images_per_frame) in random_cameras_idx]
79
+ self.key_idx_dict.setdefault(this_sub, []).extend(
80
+ [i for i in range(frame_start_idx, frame_end_idx) if (i % self.images_per_frame) in random_cameras_idx]
81
+ )
82
+
83
+ for num_i in range(0, len(self.hdfs)):
84
+ if self.hdfs[num_i]:
85
+ self.hdfs[num_i].close()
86
+ self.hdfs[num_i] = None
87
+
88
+ self.transform = wrap_transforms(transform_type, image_size=image_size)
89
+ self.__hdfs = None
90
+ self.hdf = None
91
+
92
+
93
+ def __len__(self):
94
+ return len(self.idx_to_kv)
95
+
96
+ def __del__(self):
97
+ for num_i in range(0, len(self.hdfs)):
98
+ if self.hdfs[num_i]:
99
+ self.hdfs[num_i].close()
100
+ self.hdfs[num_i] = None
101
+
102
+
103
+ @property
104
+ def archives(self):
105
+ if self.__hdfs is None: # lazy loading here!
106
+ self.__hdfs = [h5py.File(h5_path, "r", swmr=True) for h5_path in self.file_paths]
107
+ return self.__hdfs
108
+
109
+ def preprocess_image(self, image):
110
+ image = image.astype(np.float32)
111
+ if self.color_type == 'bgr':
112
+ image = image[..., ::-1]
113
+ if image.shape[0] != self.image_size[0] or image.shape[1] != self.image_size[1]:
114
+ image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
115
+
116
+ image = self.transform( image.astype(np.uint8) )
117
+ return image
118
+
119
+ def __getitem__(self, index):
120
+ key, idx = self.idx_to_kv[index]
121
+ self.hdf = self.archives[key]
122
+ assert self.hdf.swmr_mode
123
+ image = self.hdf[self.image_key][idx, :]
124
+ gaze_label = self.hdf[self.gaze_key][idx].astype('float') if self.gaze_key in self.hdf else np.array([0,0]).astype('float')
125
+ head_label = self.hdf['face_head_pose'][idx].astype('float') if 'face_head_pose' in self.hdf else np.array([0,0]).astype('float')
126
+
127
+ entry = {
128
+ 'image': self.preprocess_image(image),
129
+ 'gaze': gaze_label,
130
+ 'head': head_label,
131
+ 'key': key,
132
+ 'index':index
133
+ }
134
+
135
+ return entry
136
+
137
+