dennistrujillo committed on
Commit
8a6f09e
1 Parent(s): e1e6b7c

added dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +156 -0
dataset.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import numpy as np
3
+ import h5py, torch, random, logging
4
+ from skimage.feature import peak_local_max
5
+ from skimage import measure
6
+
7
def clean_patch(p, center):
    """Keep only the connected component of ``p`` whose peak is nearest ``center``.

    Pixels with ``p > 0`` are grouped into connected components with
    ``skimage.measure.label``. For every component the local maxima are
    located with ``peak_local_max``; the component whose mean local-maximum
    position has the smallest squared distance to ``center`` is kept and all
    other pixels are multiplied by zero.

    Args:
        p: patch array (the callers in this file pass a 2-D crop).
        center: array-like position, in patch coordinates, that the kept
            component should be closest to.

    Returns:
        ``p`` itself when it holds at most one component, or when no
        component produced a local maximum; otherwise ``p`` with every
        component except the selected one zeroed.
    """
    cc = measure.label(p > 0)
    n_components = cc.max()
    # Zero or one component: there is nothing to disambiguate.
    if n_components <= 1:
        return p

    best_label = None
    best_dist = np.inf
    for label in range(1, n_components + 1):
        peaks = peak_local_max(p * (cc == label), min_distance=1)
        if peaks.shape[0] == 0:
            # No local maximum reported for this component (e.g. a single
            # pixel, or presumably one excluded by peak_local_max's border
            # handling -- see its documentation).
            continue
        dist = ((peaks.mean(axis=0) - center) ** 2).sum()
        if dist < best_dist:
            best_label, best_dist = label, dist

    if best_label is None:
        # No component could be ranked. Comparing ``cc`` against None would
        # yield an all-False mask and silently blank the whole patch, so keep
        # the patch unchanged instead.
        return p
    return p * (cc == best_label)
25
+
26
class BraggNNDataset(Dataset):
    """Dataset of normalised patches cropped around single-peak entries.

    The peak file ``pfile`` is read for the datasets ``peak_fidx``,
    ``npeaks``, ``deviations``, ``peak_row`` and ``peak_col``; the frame file
    ``ffile`` is read for ``frames``. Only entries with ``npeaks == 1`` and
    ``0 <= deviations < 1`` are kept.

    Each item is ``(feature, label)`` where ``feature`` is the cropped patch
    with a leading axis added and ``label`` is a float32 array ``[px, py]``:
    the peak column/row inside the patch divided by ``psz``.

    Args:
        pfile: path of the HDF5 file holding the peak datasets.
        ffile: path of the HDF5 file holding ``frames``.
        psz: side length of the square patch.
        rnd_shift: when > 0, the crop window is shifted by a random integer
            in ``[-rnd_shift, rnd_shift]`` along each axis.
        use: ``'train'`` (first ``train_frac`` of the entries) or
            ``'validation'`` (the remainder).
        train_frac: fraction of entries assigned to the training split.

    Raises:
        ValueError: if ``use`` is not ``'train'`` or ``'validation'``, or if
            no entry survives the selection.
    """

    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
        # Validate the split name before opening any file: previously this
        # case was only logged and then failed on the unbound slice bounds.
        if use not in ('train', 'validation'):
            logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
            raise ValueError(f"unsupported use: {use!r}; expected 'train' or 'validation'")

        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd:
            split = int(train_frac * h5fd['peak_fidx'].shape[0])
            sti, edi = (0, split) if use == 'train' else (split, None)

            # Read once; the original fetched this dataset twice.
            deviations = h5fd['deviations'][sti:edi]
            mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
            mask = mask & ((deviations >= 0) & (deviations < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        if self.peak_fidx.shape[0] == 0:
            # .min() on an empty array raises an opaque ValueError; say why.
            raise ValueError(f"no usable single-peak entries found for use={use!r}")

        self.fidx_base = self.peak_fidx.min()
        # Load only the contiguous range of frames that is referenced.
        with h5py.File(ffile, "r") as h5fd:
            self.frames = h5fd['frames'][self.fidx_base:self.peak_fidx.max() + 1]

        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Return ``(feature, label)`` for sample ``idx`` (see class docstring)."""
        frame = self.frames[self.peak_fidx[idx] - self.fidx_base]

        # Row shift is drawn before column shift, as in the original, so a
        # seeded NumPy RNG yields the same sequence of crops.
        if self.rnd_shift > 0:
            row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift + 1)
            col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift + 1)
        else:
            row_shift, col_shift = 0, 0
        prow_rnd = int(self.peak_row[idx]) + row_shift
        pcol_rnd = int(self.peak_col[idx]) + col_shift

        half = self.psz // 2
        upper = half + self.psz % 2  # exclusive end offset; covers odd psz
        row_base = max(0, prow_rnd - half)
        col_base = max(0, pcol_rnd - half)

        crop_img = frame[row_base:prow_rnd + upper, col_base:pcol_rnd + upper]

        if crop_img.size != self.psz ** 2:
            # The window ran off the frame: zero-pad back to psz x psz,
            # splitting the missing rows/columns between both sides.
            c_pad_l = (self.psz - crop_img.shape[1]) // 2
            c_pad_r = self.psz - c_pad_l - crop_img.shape[1]

            r_pad_t = (self.psz - crop_img.shape[0]) // 2
            r_pad_b = self.psz - r_pad_t - crop_img.shape[0]

            # logging.warn is a deprecated alias (removed in Python 3.13).
            logging.warning(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
            crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
        else:
            c_pad_l, r_pad_t = 0, 0

        # Peak position expressed in (padded) patch coordinates.
        peak_r = self.peak_row[idx] - row_base + r_pad_t
        peak_c = self.peak_col[idx] - col_base + c_pad_l

        crop_img = clean_patch(crop_img, np.array([peak_r, peak_c]))

        if crop_img.max() != crop_img.min():
            # Min-max scale the patch to [0, 1].
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            # Constant patch: scaling would divide by zero, pass it through.
            logging.warning("sample %d has unique intensity sum of %d", idx, crop_img.sum())
            feature = crop_img

        px = peak_c / self.psz
        py = peak_r / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        """Return the number of selected peak entries."""
        return self.len
97
+
98
+
99
class PatchWiseDataset(Dataset):
    """Dataset that serves patches stored directly in the frame file.

    The peak file ``pfile`` is read for ``peak_fidx``, ``npeaks``,
    ``deviations``, ``peak_row`` and ``peak_col``; entries with
    ``npeaks == 1`` and ``0 <= deviations < 1`` are kept. The ``frames``
    dataset of ``ffile`` is split with the same ``train_frac`` and item
    ``idx`` is ``frames[idx]`` min-max normalised.

    Each item is ``(feature, label)`` where ``feature`` is the patch with a
    leading axis added and ``label`` is a float32 array ``[px, py]``: the
    peak column/row relative to a ``psz`` window around the integer peak
    position, divided by ``psz``.

    Args:
        pfile: path of the HDF5 file holding the peak datasets.
        ffile: path of the HDF5 file holding ``frames``.
        psz: side length used when computing the label.
        rnd_shift: stored on the instance; not used by ``__getitem__``.
        use: ``'train'`` or ``'validation'``.
        train_frac: fraction of entries assigned to the training split.

    Raises:
        ValueError: if ``use`` is not ``'train'`` or ``'validation'``, or if
            no entry survives the selection.
    """

    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
        # Validate the split name before opening any file: previously this
        # case was only logged and then failed on the unbound slice bounds.
        if use not in ('train', 'validation'):
            logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
            raise ValueError(f"unsupported use: {use!r}; expected 'train' or 'validation'")

        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd:
            split = int(train_frac * h5fd['peak_fidx'].shape[0])
            sti, edi = (0, split) if use == 'train' else (split, None)

            # Read once; the original fetched this dataset twice.
            deviations = h5fd['deviations'][sti:edi]
            mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
            mask = mask & ((deviations >= 0) & (deviations < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        if self.peak_fidx.shape[0] == 0:
            # .min() on an empty array raises an opaque ValueError; say why.
            raise ValueError(f"no usable single-peak entries found for use={use!r}")

        self.fidx_base = self.peak_fidx.min()

        with h5py.File(ffile, 'r') as h5fd:
            split = int(train_frac * h5fd['frames'].shape[0])
            sti, edi = (0, split) if use == 'train' else (split, None)
            patches = h5fd['frames'][sti:edi]

        # __getitem__ indexes the patches and the (masked) peak arrays with
        # the same idx, so the patches must be filtered by the same mask or
        # the two go out of step as soon as one entry is dropped.
        # NOTE(review): assumes 'frames' holds one patch per peak entry; when
        # the lengths differ the patches are left unfiltered -- confirm.
        if mask.ndim == 1 and patches.shape[0] == mask.shape[0]:
            patches = patches[mask]

        self.crop_img = patches
        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Return ``(feature, label)`` for sample ``idx`` (see class docstring)."""
        crop_img = self.crop_img[idx]

        # Origin of the psz window around the integer peak position; no
        # random shift or padding is applied in this dataset.
        half = self.psz // 2
        row_base = max(0, int(self.peak_row[idx]) - half)
        col_base = max(0, int(self.peak_col[idx]) - half)

        if crop_img.max() != crop_img.min():
            # Min-max scale the patch to [0, 1].
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            # Constant patch: scaling would divide by zero, pass it through.
            feature = crop_img

        px = (self.peak_col[idx] - col_base) / self.psz
        py = (self.peak_row[idx] - row_base) / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        """Return the number of selected peak entries."""
        return self.len
156
+