Niv Sardi committed on
Commit
9996fa3
1 Parent(s): 74a29fd

augmentor: pass alphas as heatmaps

Browse files

Signed-off-by: Niv Sardi <xaiki@evilgiggle.com>

Files changed (2) hide show
  1. python/augment.py +17 -6
  2. python/imtool.py +23 -4
python/augment.py CHANGED
@@ -24,6 +24,8 @@ BATCH_SIZE = 16
24
  mkdir.make_dirs([defaults.AUGMENTED_IMAGES_PATH, defaults.AUGMENTED_LABELS_PATH])
25
 
26
  logo_images = []
 
 
27
  background_images = [d for d in os.scandir(defaults.IMAGES_PATH)]
28
 
29
  stats = {
@@ -57,14 +59,18 @@ for d in os.scandir(defaults.LOGOS_DATA_PATH):
57
 
58
  assert(w > 10)
59
  assert(h > 10)
 
 
 
 
 
60
 
61
- logo_images.append(img)
62
  except Exception as e:
63
  stats['failed'] += 1
64
  print(f'error loading: {d.path}: {e}')
65
 
66
  print(stats)
67
- batches = [UnnormalizedBatch(images=logo_images[i:i+BATCH_SIZE])
68
  for i in range(math.floor(len(logo_images)/BATCH_SIZE))]
69
 
70
  # We use a single, very fast augmenter here to show that batches
@@ -91,12 +97,17 @@ with pipeline.pool(processes=-1, seed=1) as pool:
91
 
92
  anotations = []
93
  for k in range(math.floor(len(batch_aug.images_aug)/3)):
94
- logo = batch_aug.images_aug[(j+k)%len(batch_aug.images_aug)]
 
 
 
 
 
95
  try:
96
- img, bb, (w, h) = imtool.mix(img, logo, random.random(), random.random())
97
  anotations.append(f'0 {bb.x/w} {bb.y/h} {bb.w/w} {bb.h/h}')
98
- except AssertionError:
99
- print(f'couldnt process {i}, {j}')
100
 
101
  try:
102
  cv2.imwrite(f'{defaults.AUGMENTED_IMAGES_PATH}/{basename}.png', img)
 
24
  mkdir.make_dirs([defaults.AUGMENTED_IMAGES_PATH, defaults.AUGMENTED_LABELS_PATH])
25
 
26
  logo_images = []
27
+ logo_alphas = []
28
+
29
  background_images = [d for d in os.scandir(defaults.IMAGES_PATH)]
30
 
31
  stats = {
 
59
 
60
  assert(w > 10)
61
  assert(h > 10)
62
+ (b, g, r, _) = cv2.split(img)
63
+ alpha = img[:, :, 3]/255
64
+ logo_images.append(cv2.merge([b, g, r]))
65
+ # XXX(xaiki): we pass alpha as a float32 heatmap, because imgaug is pretty strict about what data it will process
66
+ logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
67
 
 
68
  except Exception as e:
69
  stats['failed'] += 1
70
  print(f'error loading: {d.path}: {e}')
71
 
72
  print(stats)
73
+ batches = [UnnormalizedBatch(images=logo_images[i:i+BATCH_SIZE],heatmaps=logo_alphas[i:i+BATCH_SIZE])
74
  for i in range(math.floor(len(logo_images)/BATCH_SIZE))]
75
 
76
  # We use a single, very fast augmenter here to show that batches
 
97
 
98
  anotations = []
99
  for k in range(math.floor(len(batch_aug.images_aug)/3)):
100
+ logo_idx = (j+k*4)%len(batch_aug.images_aug)
101
+ logo = batch_aug.images_aug[logo_idx]
102
+
103
+ # XXX(xaiki): we get alpha from heatmap, but will only use one channel
104
+ # we could make mix_alpha into mix_mask and pass all 3 channels
105
+ alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
106
  try:
107
+ img, bb, (w, h) = imtool.mix_alpha(img, logo, alpha[0], random.random(), random.random())
108
  anotations.append(f'0 {bb.x/w} {bb.y/h} {bb.w/w} {bb.h/h}')
109
+ except AssertionError as e:
110
+ print(f'couldnt process {i}, {j}: {e}')
111
 
112
  try:
113
  cv2.imwrite(f'{defaults.AUGMENTED_IMAGES_PATH}/{basename}.png', img)
python/imtool.py CHANGED
@@ -89,12 +89,32 @@ def remove_white(img):
89
 
90
  return rect
91
 
 
92
  def mix(a, b, fx, fy):
 
 
 
 
 
93
  (ah, aw, ac) = a.shape
94
  (bh, bw, bc) = b.shape
95
 
96
- assert(aw > bw)
97
- assert(ah > bh)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  x = math.floor(fx*(aw - bw))
100
  y = math.floor(fy*(ah - bh))
@@ -102,8 +122,7 @@ def mix(a, b, fx, fy):
102
  # handle transparency
103
  mat = a[y:y+bh,x:x+bw]
104
  cols = b[:, :, :3]
105
- alpha = b[:, :, 3]/255
106
- mask = np.dstack((alpha, alpha, alpha))
107
 
108
  a[y:y+bh,x:x+bw] = mat * (1 - mask) + cols * mask
109
 
 
89
 
90
  return rect
91
 
92
+
93
def mix(a, b, fx, fy):
    """Blend `b` onto `a`, taking the alpha from `b`'s own fourth channel.

    `b` must have at least four channels; channel index 3 is divided by 255
    (presumably an 8-bit alpha plane -- confirm against callers) and handed
    to `_mix_alpha` together with the unchanged `a`, `b`, `fx` and `fy`.

    Returns whatever `_mix_alpha` returns.
    """
    return _mix_alpha(a, b, b[:, :, 3] / 255, fx, fy)
97
+
98
def mix_alpha(a, b, ba, fx, fy):
    """Blend `b` onto `a` using `ba` as alpha, shrinking `b` first if it does not fit.

    Args:
        a: destination image; must be 3-dimensional (height, width, channels).
        b: image to paste; must be 3-dimensional (height, width, channels).
        ba: alpha plane for `b`; it is resized alongside `b` so both stay the
            same size.
        fx, fy: placement fractions forwarded untouched to `_mix_alpha`.

    Returns:
        Whatever `_mix_alpha` returns.

    Raises:
        AssertionError: if `b` (after any resize) is 10 pixels or fewer wide
            or high.
    """
    (ah, aw, _) = a.shape
    (bh, bw, _) = b.shape

    if aw < bw or ah < bh:
        # Scale by the tighter of the two dimensions. The previous width-only
        # factor (0.2*aw/bw) left over-tall images too tall, and the recursive
        # retry then recomputed a factor of ~1 forever (RecursionError).
        # With min() the result is at most 20% of `a` in both directions, so a
        # single pass is always enough.
        f = 0.2 * min(aw / bw, ah / bh)
        # floor_point is defined elsewhere in this module; presumably it
        # returns an integer (width, height) pair suitable for cv2.resize.
        size = floor_point(bw * f, bh * f)
        print(f'resizing, factor {f} to fit in {aw}x{ah}\n -- {bw}x{bh} => {size}')
        b = cv2.resize(b, size, interpolation=cv2.INTER_LINEAR)
        ba = cv2.resize(ba, size, interpolation=cv2.INTER_LINEAR)
        (bh, bw) = b.shape[:2]

    assert bw > 10, f'b({bw}) too small'
    assert bh > 10, f'b({bh}) too small'

    return _mix_alpha(a, b, ba, fx, fy)
114
+
115
+ def _mix_alpha(a, b, ba, fx, fy):
116
+ (ah, aw, ac) = a.shape
117
+ (bh, bw, bc) = b.shape
118
 
119
  x = math.floor(fx*(aw - bw))
120
  y = math.floor(fy*(ah - bh))
 
122
  # handle transparency
123
  mat = a[y:y+bh,x:x+bw]
124
  cols = b[:, :, :3]
125
+ mask = np.dstack((ba, ba, ba))
 
126
 
127
  a[y:y+bh,x:x+bw] = mat * (1 - mask) + cols * mask
128