Niv Sardi commited on
Commit
6ef300c
1 Parent(s): 8beee8d

correctly import labels

Browse files

Signed-off-by: Niv Sardi <xaiki@evilgiggle.com>

Files changed (1) hide show
  1. python/augment.py +18 -4
python/augment.py CHANGED
@@ -25,6 +25,7 @@ mkdir.make_dirs([defaults.AUGMENTED_IMAGES_PATH, defaults.AUGMENTED_LABELS_PATH]
25
 
26
  logo_images = []
27
  logo_alphas = []
 
28
 
29
  background_images = [d for d in os.scandir(defaults.IMAGES_PATH)]
30
 
@@ -45,7 +46,7 @@ for d in os.scandir(defaults.LOGOS_DATA_PATH):
45
  else:
46
  png = svg2png(url=d.path)
47
  img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
48
- stats['ok'] += 1
49
 
50
  (h, w, c) = img.shape
51
  if c == 3:
@@ -59,10 +60,20 @@ for d in os.scandir(defaults.LOGOS_DATA_PATH):
59
 
60
  assert(w > 10)
61
  assert(h > 10)
 
 
 
62
  (b, g, r, _) = cv2.split(img)
63
  alpha = img[:, :, 3]/255
64
- logo_images.append(cv2.merge([b, g, r]))
65
- # XXX(xaiki): we pass alpha as a float32 heatmap, because imgaug is pretty strict about what data it will process
 
 
 
 
 
 
 
66
  logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
67
 
68
  except Exception as e:
@@ -98,6 +109,9 @@ with pipeline.pool(processes=-1, seed=1) as pool:
98
  anotations = []
99
  for k in range(math.floor(len(batch_aug.images_aug)/3)):
100
  logo_idx = (j+k*4)%len(batch_aug.images_aug)
 
 
 
101
  logo = batch_aug.images_aug[logo_idx]
102
 
103
  # XXX(xaiki): we get alpha from heatmap, but will only use one channel
@@ -106,7 +120,7 @@ with pipeline.pool(processes=-1, seed=1) as pool:
106
  try:
107
  img, bb, (w, h) = imtool.mix_alpha(img, logo, alpha[0], random.random(), random.random())
108
  c = bb.to_centroid((h, w, 1))
109
- anotations.append(c.to_anotation(0))
110
  except AssertionError as e:
111
  print(f'couldnt process {i}, {j}: {e}')
112
 
 
25
 
26
  logo_images = []
27
  logo_alphas = []
28
+ logo_labels = {}
29
 
30
  background_images = [d for d in os.scandir(defaults.IMAGES_PATH)]
31
 
 
46
  else:
47
  png = svg2png(url=d.path)
48
  img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
49
+ label = d.name.split('.')[0]
50
 
51
  (h, w, c) = img.shape
52
  if c == 3:
 
60
 
61
  assert(w > 10)
62
  assert(h > 10)
63
+
64
+ stats['ok'] += 1
65
+
66
  (b, g, r, _) = cv2.split(img)
67
  alpha = img[:, :, 3]/255
68
+ d = cv2.merge([b, g, r])
69
+
70
+ logo_images.append(d)
71
+ # tried id() tried __array_interface__, tried tagging, nothing works
72
+ logo_labels.update({d.tobytes(): label})
73
+
74
+ # XXX(xaiki): we pass alpha as a float32 heatmap,
75
+ # because imgaug is pretty strict about what data it will process
76
+ # and that we want the alpha layer to pass the same transformations as the orig
77
  logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
78
 
79
  except Exception as e:
 
109
  anotations = []
110
  for k in range(math.floor(len(batch_aug.images_aug)/3)):
111
  logo_idx = (j+k*4)%len(batch_aug.images_aug)
112
+
113
+ orig = batch_aug.images_unaug[logo_idx]
114
+ label = logo_labels[orig.tobytes()]
115
  logo = batch_aug.images_aug[logo_idx]
116
 
117
  # XXX(xaiki): we get alpha from heatmap, but will only use one channel
 
120
  try:
121
  img, bb, (w, h) = imtool.mix_alpha(img, logo, alpha[0], random.random(), random.random())
122
  c = bb.to_centroid((h, w, 1))
123
+ anotations.append(c.to_anotation(label))
124
  except AssertionError as e:
125
  print(f'couldnt process {i}, {j}: {e}')
126