dennistrujillo
commited on
Commit
•
8a6f09e
1
Parent(s):
e1e6b7c
added dataset.py
Browse files- dataset.py +156 -0
dataset.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import numpy as np
|
3 |
+
import h5py, torch, random, logging
|
4 |
+
from skimage.feature import peak_local_max
|
5 |
+
from skimage import measure
|
6 |
+
|
7 |
+
def clean_patch(p, center):
|
8 |
+
w, h = p.shape
|
9 |
+
cc = measure.label(p > 0)
|
10 |
+
if cc.max() == 1:
|
11 |
+
return p
|
12 |
+
|
13 |
+
# logging.warn(f"{cc.max()} peaks located in a patch")
|
14 |
+
lmin = np.inf
|
15 |
+
cc_lmin = None
|
16 |
+
for _c in range(1, cc.max()+1):
|
17 |
+
lmax = peak_local_max(p * (cc==_c), min_distance=1)
|
18 |
+
if lmax.shape[0] == 0:continue # single pixel component
|
19 |
+
lc = lmax.mean(axis=0)
|
20 |
+
dist = ((lc - center)**2).sum()
|
21 |
+
if dist < lmin:
|
22 |
+
cc_lmin = _c
|
23 |
+
lmin = dist
|
24 |
+
return p * (cc == cc_lmin)
|
25 |
+
|
26 |
+
class BraggNNDataset(Dataset):
|
27 |
+
def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
|
28 |
+
self.psz = psz
|
29 |
+
self.rnd_shift = rnd_shift
|
30 |
+
|
31 |
+
with h5py.File(pfile, "r") as h5fd:
|
32 |
+
if use == 'train':
|
33 |
+
sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
|
34 |
+
elif use == 'validation':
|
35 |
+
sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
|
36 |
+
else:
|
37 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
38 |
+
|
39 |
+
mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
|
40 |
+
mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
|
41 |
+
|
42 |
+
self.peak_fidx= h5fd['peak_fidx'][sti:edi][mask]
|
43 |
+
self.peak_row = h5fd['peak_row'][sti:edi][mask]
|
44 |
+
self.peak_col = h5fd['peak_col'][sti:edi][mask]
|
45 |
+
|
46 |
+
self.fidx_base = self.peak_fidx.min()
|
47 |
+
# only loaded frames that will be used
|
48 |
+
with h5py.File(ffile, "r") as h5fd:
|
49 |
+
self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]
|
50 |
+
|
51 |
+
self.len = self.peak_fidx.shape[0]
|
52 |
+
|
53 |
+
def __getitem__(self, idx):
|
54 |
+
_frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
|
55 |
+
if self.rnd_shift > 0:
|
56 |
+
row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
|
57 |
+
col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
|
58 |
+
else:
|
59 |
+
row_shift, col_shift = 0, 0
|
60 |
+
prow_rnd = int(self.peak_row[idx]) + row_shift
|
61 |
+
pcol_rnd = int(self.peak_col[idx]) + col_shift
|
62 |
+
|
63 |
+
row_base = max(0, prow_rnd-self.psz//2)
|
64 |
+
col_base = max(0, pcol_rnd-self.psz//2 )
|
65 |
+
|
66 |
+
crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2), \
|
67 |
+
col_base:(pcol_rnd + self.psz//2 + self.psz%2)]
|
68 |
+
# if((crop_img > 0).sum() == 1): continue # ignore single non-zero peak
|
69 |
+
if crop_img.size != self.psz ** 2:
|
70 |
+
c_pad_l = (self.psz - crop_img.shape[1]) // 2
|
71 |
+
c_pad_r = self.psz - c_pad_l - crop_img.shape[1]
|
72 |
+
|
73 |
+
r_pad_t = (self.psz - crop_img.shape[0]) // 2
|
74 |
+
r_pad_b = self.psz - r_pad_t - crop_img.shape[0]
|
75 |
+
|
76 |
+
logging.warn(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
|
77 |
+
crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
|
78 |
+
else:
|
79 |
+
c_pad_l, r_pad_t = 0 ,0
|
80 |
+
|
81 |
+
_center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
|
82 |
+
crop_img = clean_patch(crop_img, _center)
|
83 |
+
if crop_img.max() != crop_img.min():
|
84 |
+
_min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
|
85 |
+
feature = (crop_img - _min) / (_max - _min)
|
86 |
+
else:
|
87 |
+
logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
|
88 |
+
feature = crop_img
|
89 |
+
|
90 |
+
px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
|
91 |
+
py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
|
92 |
+
|
93 |
+
return feature[np.newaxis], np.array([px, py]).astype(np.float32)
|
94 |
+
|
95 |
+
def __len__(self):
|
96 |
+
return self.len
|
97 |
+
|
98 |
+
|
99 |
+
class PatchWiseDataset(Dataset):
|
100 |
+
def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
|
101 |
+
self.psz = psz
|
102 |
+
self.rnd_shift = rnd_shift
|
103 |
+
with h5py.File(pfile, "r") as h5fd:
|
104 |
+
if use == 'train':
|
105 |
+
sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
|
106 |
+
elif use == 'validation':
|
107 |
+
sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
|
108 |
+
else:
|
109 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
110 |
+
|
111 |
+
mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
|
112 |
+
mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))
|
113 |
+
|
114 |
+
self.peak_fidx= h5fd['peak_fidx'][sti:edi][mask]
|
115 |
+
self.peak_row = h5fd['peak_row'][sti:edi][mask]
|
116 |
+
self.peak_col = h5fd['peak_col'][sti:edi][mask]
|
117 |
+
|
118 |
+
self.fidx_base = self.peak_fidx.min()
|
119 |
+
# only loaded frames that will be used
|
120 |
+
with h5py.File(ffile, 'r') as h5fd:
|
121 |
+
if use == 'train':
|
122 |
+
sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
|
123 |
+
elif use == 'validation':
|
124 |
+
sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
|
125 |
+
else:
|
126 |
+
logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
|
127 |
+
|
128 |
+
self.crop_img = h5fd['frames'][sti:edi]
|
129 |
+
self.len = self.peak_fidx.shape[0]
|
130 |
+
|
131 |
+
def __getitem__(self, idx):
|
132 |
+
crop_img = self.crop_img[idx]
|
133 |
+
|
134 |
+
row_shift, col_shift = 0, 0
|
135 |
+
c_pad_l, r_pad_t = 0 ,0
|
136 |
+
prow_rnd = int(self.peak_row[idx]) + row_shift
|
137 |
+
pcol_rnd = int(self.peak_col[idx]) + col_shift
|
138 |
+
|
139 |
+
row_base = max(0, prow_rnd-self.psz//2)
|
140 |
+
col_base = max(0, pcol_rnd-self.psz//2)
|
141 |
+
|
142 |
+
if crop_img.max() != crop_img.min():
|
143 |
+
_min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
|
144 |
+
feature = (crop_img - _min) / (_max - _min)
|
145 |
+
else:
|
146 |
+
#logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
|
147 |
+
feature = crop_img
|
148 |
+
|
149 |
+
px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
|
150 |
+
py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz
|
151 |
+
|
152 |
+
return feature[np.newaxis], np.array([px, py]).astype(np.float32)
|
153 |
+
|
154 |
+
def __len__(self):
|
155 |
+
return self.len
|
156 |
+
|