George commited on
Commit
f803e88
1 Parent(s): e42312a

modified based on upl app.py commit

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. khandy/label/detect.py +76 -64
  3. khandy/split_utils.py +9 -11
app.py CHANGED
@@ -22,7 +22,7 @@ def inference(filename):
22
  image_height, image_width = image.shape[:2]
23
  boxes, confs, classes = detector.detect(image)
24
 
25
- for box, conf, class_ind in zip(boxes, confs, classes):
26
  box = box.astype(np.int32)
27
  box_width = box[2] - box[0] + 1
28
  box_height = box[3] - box[1] + 1
@@ -38,7 +38,8 @@ def inference(filename):
38
  if prob < 0.10:
39
  text = 'Unknown'
40
  else:
41
- text = '{}: {:.2f}%'.format(
 
42
  results[0]['latin_name'],
43
  100.0 * results[0]['probability']
44
  )
 
22
  image_height, image_width = image.shape[:2]
23
  boxes, confs, classes = detector.detect(image)
24
 
25
+ for box, _, _ in zip(boxes, confs, classes):
26
  box = box.astype(np.int32)
27
  box_width = box[2] - box[0] + 1
28
  box_height = box[3] - box[1] + 1
 
38
  if prob < 0.10:
39
  text = 'Unknown'
40
  else:
41
+ text = '{}({}): {:.2f}%'.format(
42
+ results[0]['chinese_name'],
43
  results[0]['latin_name'],
44
  100.0 * results[0]['probability']
45
  )
khandy/label/detect.py CHANGED
@@ -13,7 +13,7 @@ import lxml.builder
13
  import numpy as np
14
 
15
 
16
- __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
17
  'save_detect', 'convert_detect', 'replace_detect_label',
18
  'load_coco_class_names']
19
 
@@ -27,8 +27,8 @@ class DetectIrObject:
27
  y_min: float
28
  x_max: float
29
  y_max: float
30
-
31
-
32
  @dataclass
33
  class DetectIrRecord:
34
  """Intermediate Representation Format of Record
@@ -37,30 +37,30 @@ class DetectIrRecord:
37
  width: int
38
  height: int
39
  objects: List[DetectIrObject] = field(default_factory=list)
40
-
41
-
42
  @dataclass
43
  class PascalVocSource:
44
  database: str = ''
45
  annotation: str = ''
46
  image: str = ''
47
-
48
-
49
  @dataclass
50
  class PascalVocSize:
51
  height: int
52
  width: int
53
  depth: int
54
-
55
-
56
  @dataclass
57
  class PascalVocBndbox:
58
  xmin: float
59
  ymin: float
60
  xmax: float
61
  ymax: float
62
-
63
-
64
  @dataclass
65
  class PascalVocObject:
66
  name: str
@@ -68,8 +68,8 @@ class PascalVocObject:
68
  truncated: int = 0
69
  difficult: int = 0
70
  bndbox: Optional[PascalVocBndbox] = None
71
-
72
-
73
  @dataclass
74
  class PascalVocRecord:
75
  folder: str = ''
@@ -79,33 +79,33 @@ class PascalVocRecord:
79
  size: Optional[PascalVocSize] = None
80
  segmented: int = 0
81
  objects: List[PascalVocObject] = field(default_factory=list)
82
-
83
-
84
  class PascalVocHandler:
85
  @staticmethod
86
  def load(filename, **kwargs) -> PascalVocRecord:
87
  pascal_voc_record = PascalVocRecord()
88
-
89
  xml_tree = ET.parse(filename)
90
  pascal_voc_record.folder = xml_tree.find('folder').text
91
  pascal_voc_record.filename = xml_tree.find('filename').text
92
  pascal_voc_record.path = xml_tree.find('path').text
93
  pascal_voc_record.segmented = xml_tree.find('segmented').text
94
-
95
  source_tag = xml_tree.find('source')
96
  pascal_voc_record.source = PascalVocSource(
97
  database=source_tag.find('database').text,
98
  # annotation=source_tag.find('annotation').text,
99
  # image=source_tag.find('image').text
100
  )
101
-
102
  size_tag = xml_tree.find('size')
103
  pascal_voc_record.size = PascalVocSize(
104
  width=int(size_tag.find('width').text),
105
  height=int(size_tag.find('height').text),
106
  depth=int(size_tag.find('depth').text)
107
  )
108
-
109
  object_tags = xml_tree.findall('object')
110
  for index, object_tag in enumerate(object_tags):
111
  bndbox_tag = object_tag.find('bndbox')
@@ -124,7 +124,7 @@ class PascalVocHandler:
124
  )
125
  pascal_voc_record.objects.append(pascal_voc_object)
126
  return pascal_voc_record
127
-
128
  @staticmethod
129
  def save(filename, pascal_voc_record: PascalVocRecord):
130
  maker = lxml.builder.ElementMaker()
@@ -135,14 +135,14 @@ class PascalVocHandler:
135
  maker.source(
136
  maker.database(pascal_voc_record.source.database),
137
  ),
138
- maker.size(
139
  maker.width(str(pascal_voc_record.size.width)),
140
  maker.height(str(pascal_voc_record.size.height)),
141
  maker.depth(str(pascal_voc_record.size.depth)),
142
  ),
143
  maker.segmented(str(pascal_voc_record.segmented)),
144
  )
145
-
146
  for pascal_voc_object in pascal_voc_record.objects:
147
  object_tag = maker.object(
148
  maker.name(pascal_voc_object.name),
@@ -157,12 +157,13 @@ class PascalVocHandler:
157
  ),
158
  )
159
  xml.append(object_tag)
160
-
161
  if not filename.endswith('.xml'):
162
  filename = filename + '.xml'
163
  with open(filename, 'wb') as f:
164
- f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
165
-
 
166
  @staticmethod
167
  def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
168
  ir_record = DetectIrRecord(
@@ -180,7 +181,7 @@ class PascalVocHandler:
180
  )
181
  ir_record.objects.append(ir_object)
182
  return ir_record
183
-
184
  @staticmethod
185
  def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
186
  pascal_voc_record = PascalVocRecord(
@@ -203,10 +204,11 @@ class PascalVocHandler:
203
  )
204
  pascal_voc_record.objects.append(pascal_voc_object)
205
  return pascal_voc_record
206
-
207
-
208
  class _NumpyEncoder(json.JSONEncoder):
209
  """ Special json encoder for numpy types """
 
210
  def default(self, obj):
211
  if isinstance(obj, (np.bool_,)):
212
  return bool(obj)
@@ -279,7 +281,7 @@ class LabelmeHandler:
279
  )
280
  ir_record.objects.append(ir_object)
281
  return ir_record
282
-
283
  @staticmethod
284
  def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
285
  labelme_record = LabelmeRecord(
@@ -291,12 +293,12 @@ class LabelmeHandler:
291
  labelme_shape = LabelmeShape(
292
  label=ir_object.label,
293
  shape_type='rectangle',
294
- points=[[ir_object.x_min, ir_object.y_min],
295
  [ir_object.x_max, ir_object.y_max]]
296
  )
297
  labelme_record.shapes.append(labelme_shape)
298
  return labelme_record
299
-
300
 
301
  @dataclass
302
  class YoloObject:
@@ -305,16 +307,16 @@ class YoloObject:
305
  y_center: float
306
  width: float
307
  height: float
308
-
309
-
310
  @dataclass
311
  class YoloRecord:
312
  filename: Optional[str] = None
313
  width: Optional[int] = None
314
  height: Optional[int] = None
315
  objects: List[YoloObject] = field(default_factory=list)
316
-
317
-
318
  class YoloHandler:
319
  @staticmethod
320
  def load(filename, **kwargs) -> YoloRecord:
@@ -341,7 +343,8 @@ class YoloHandler:
341
  def save(filename, yolo_record: YoloRecord):
342
  records = []
343
  for object in yolo_record.objects:
344
- records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
 
345
  if not filename.endswith('.txt'):
346
  filename = filename + '.txt'
347
  khandy.save_list(filename, records)
@@ -354,10 +357,14 @@ class YoloHandler:
354
  height=yolo_record.height
355
  )
356
  for yolo_object in yolo_record.objects:
357
- x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
358
- y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
359
- x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
360
- y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
 
 
 
 
361
  ir_object = DetectIrObject(
362
  label=yolo_object.label,
363
  x_min=x_min,
@@ -367,7 +374,7 @@ class YoloHandler:
367
  )
368
  ir_record.objects.append(ir_object)
369
  return ir_record
370
-
371
  @staticmethod
372
  def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
373
  yolo_record = YoloRecord(
@@ -376,8 +383,10 @@ class YoloHandler:
376
  height=ir_record.height
377
  )
378
  for ir_object in ir_record.objects:
379
- x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
380
- y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
 
 
381
  width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
382
  height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
383
  yolo_object = YoloObject(
@@ -389,8 +398,8 @@ class YoloHandler:
389
  )
390
  yolo_record.objects.append(yolo_object)
391
  return yolo_record
392
-
393
-
394
  @dataclass
395
  class CocoObject:
396
  label: str
@@ -398,29 +407,29 @@ class CocoObject:
398
  y_min: float
399
  width: float
400
  height: float
401
-
402
-
403
  @dataclass
404
  class CocoRecord:
405
  filename: str
406
  width: int
407
  height: int
408
  objects: List[CocoObject] = field(default_factory=list)
409
-
410
 
411
  class CocoHandler:
412
  @staticmethod
413
  def load(filename, **kwargs) -> List[CocoRecord]:
414
  json_data = khandy.load_json(filename)
415
-
416
  images = json_data['images']
417
  annotations = json_data['annotations']
418
  categories = json_data['categories']
419
-
420
  label_map = {}
421
  for cat_item in categories:
422
  label_map[cat_item['id']] = cat_item['name']
423
-
424
  coco_records = OrderedDict()
425
  for image_item in images:
426
  coco_records[image_item['id']] = CocoRecord(
@@ -428,7 +437,7 @@ class CocoHandler:
428
  width=image_item['width'],
429
  height=image_item['height'],
430
  objects=[])
431
-
432
  for annotation_item in annotations:
433
  coco_object = CocoObject(
434
  label=label_map[annotation_item['category_id']],
@@ -436,9 +445,10 @@ class CocoHandler:
436
  y_min=annotation_item['bbox'][1],
437
  width=annotation_item['bbox'][2],
438
  height=annotation_item['bbox'][3])
439
- coco_records[annotation_item['image_id']].objects.append(coco_object)
 
440
  return list(coco_records.values())
441
-
442
  @staticmethod
443
  def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
444
  ir_record = DetectIrRecord(
@@ -474,8 +484,8 @@ class CocoHandler:
474
  )
475
  coco_record.objects.append(coco_object)
476
  return coco_record
477
-
478
-
479
  def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
480
  if fmt == 'labelme':
481
  labelme_record = LabelmeHandler.load(filename, **kwargs)
@@ -488,12 +498,13 @@ def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
488
  ir_record = PascalVocHandler.to_ir(pascal_voc_record)
489
  elif fmt == 'coco':
490
  coco_records = CocoHandler.load(filename, **kwargs)
491
- ir_record = [CocoHandler.to_ir(coco_record) for coco_record in coco_records]
 
492
  else:
493
  raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
494
  return ir_record
495
-
496
-
497
  def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
498
  os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
499
  if out_fmt == 'labelme':
@@ -527,9 +538,11 @@ def _get_format(record):
527
 
528
 
529
  def convert_detect(record, out_fmt):
530
- allowed_fmts = ('labelme', 'yolo', 'voc', 'coco', 'pascal', 'pascal_voc', 'ir', 'detect_ir')
 
531
  if out_fmt not in allowed_fmts:
532
- raise ValueError("Unsupported label format conversions for given out_fmt")
 
533
  if out_fmt in _get_format(record):
534
  return record
535
 
@@ -545,7 +558,7 @@ def convert_detect(record, out_fmt):
545
  ir_record = record
546
  else:
547
  raise TypeError('Unsupported type for record')
548
-
549
  if out_fmt in ('ir', 'detect_ir'):
550
  dst_record = ir_record
551
  elif out_fmt == 'labelme':
@@ -557,7 +570,7 @@ def convert_detect(record, out_fmt):
557
  elif out_fmt == 'coco':
558
  dst_record = CocoHandler.from_ir(ir_record)
559
  return dst_record
560
-
561
 
562
  def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
563
  dst_record = copy.deepcopy(record)
@@ -579,4 +592,3 @@ def load_coco_class_names(filename):
579
  json_data = khandy.load_json(filename)
580
  categories = json_data['categories']
581
  return [cat_item['name'] for cat_item in categories]
582
-
 
13
  import numpy as np
14
 
15
 
16
+ __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
17
  'save_detect', 'convert_detect', 'replace_detect_label',
18
  'load_coco_class_names']
19
 
 
27
  y_min: float
28
  x_max: float
29
  y_max: float
30
+
31
+
32
  @dataclass
33
  class DetectIrRecord:
34
  """Intermediate Representation Format of Record
 
37
  width: int
38
  height: int
39
  objects: List[DetectIrObject] = field(default_factory=list)
40
+
41
+
42
  @dataclass
43
  class PascalVocSource:
44
  database: str = ''
45
  annotation: str = ''
46
  image: str = ''
47
+
48
+
49
  @dataclass
50
  class PascalVocSize:
51
  height: int
52
  width: int
53
  depth: int
54
+
55
+
56
  @dataclass
57
  class PascalVocBndbox:
58
  xmin: float
59
  ymin: float
60
  xmax: float
61
  ymax: float
62
+
63
+
64
  @dataclass
65
  class PascalVocObject:
66
  name: str
 
68
  truncated: int = 0
69
  difficult: int = 0
70
  bndbox: Optional[PascalVocBndbox] = None
71
+
72
+
73
  @dataclass
74
  class PascalVocRecord:
75
  folder: str = ''
 
79
  size: Optional[PascalVocSize] = None
80
  segmented: int = 0
81
  objects: List[PascalVocObject] = field(default_factory=list)
82
+
83
+
84
  class PascalVocHandler:
85
  @staticmethod
86
  def load(filename, **kwargs) -> PascalVocRecord:
87
  pascal_voc_record = PascalVocRecord()
88
+
89
  xml_tree = ET.parse(filename)
90
  pascal_voc_record.folder = xml_tree.find('folder').text
91
  pascal_voc_record.filename = xml_tree.find('filename').text
92
  pascal_voc_record.path = xml_tree.find('path').text
93
  pascal_voc_record.segmented = xml_tree.find('segmented').text
94
+
95
  source_tag = xml_tree.find('source')
96
  pascal_voc_record.source = PascalVocSource(
97
  database=source_tag.find('database').text,
98
  # annotation=source_tag.find('annotation').text,
99
  # image=source_tag.find('image').text
100
  )
101
+
102
  size_tag = xml_tree.find('size')
103
  pascal_voc_record.size = PascalVocSize(
104
  width=int(size_tag.find('width').text),
105
  height=int(size_tag.find('height').text),
106
  depth=int(size_tag.find('depth').text)
107
  )
108
+
109
  object_tags = xml_tree.findall('object')
110
  for index, object_tag in enumerate(object_tags):
111
  bndbox_tag = object_tag.find('bndbox')
 
124
  )
125
  pascal_voc_record.objects.append(pascal_voc_object)
126
  return pascal_voc_record
127
+
128
  @staticmethod
129
  def save(filename, pascal_voc_record: PascalVocRecord):
130
  maker = lxml.builder.ElementMaker()
 
135
  maker.source(
136
  maker.database(pascal_voc_record.source.database),
137
  ),
138
+ maker.size(
139
  maker.width(str(pascal_voc_record.size.width)),
140
  maker.height(str(pascal_voc_record.size.height)),
141
  maker.depth(str(pascal_voc_record.size.depth)),
142
  ),
143
  maker.segmented(str(pascal_voc_record.segmented)),
144
  )
145
+
146
  for pascal_voc_object in pascal_voc_record.objects:
147
  object_tag = maker.object(
148
  maker.name(pascal_voc_object.name),
 
157
  ),
158
  )
159
  xml.append(object_tag)
160
+
161
  if not filename.endswith('.xml'):
162
  filename = filename + '.xml'
163
  with open(filename, 'wb') as f:
164
+ f.write(lxml.etree.tostring(
165
+ xml, pretty_print=True, encoding='utf-8'))
166
+
167
  @staticmethod
168
  def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
169
  ir_record = DetectIrRecord(
 
181
  )
182
  ir_record.objects.append(ir_object)
183
  return ir_record
184
+
185
  @staticmethod
186
  def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
187
  pascal_voc_record = PascalVocRecord(
 
204
  )
205
  pascal_voc_record.objects.append(pascal_voc_object)
206
  return pascal_voc_record
207
+
208
+
209
  class _NumpyEncoder(json.JSONEncoder):
210
  """ Special json encoder for numpy types """
211
+
212
  def default(self, obj):
213
  if isinstance(obj, (np.bool_,)):
214
  return bool(obj)
 
281
  )
282
  ir_record.objects.append(ir_object)
283
  return ir_record
284
+
285
  @staticmethod
286
  def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
287
  labelme_record = LabelmeRecord(
 
293
  labelme_shape = LabelmeShape(
294
  label=ir_object.label,
295
  shape_type='rectangle',
296
+ points=[[ir_object.x_min, ir_object.y_min],
297
  [ir_object.x_max, ir_object.y_max]]
298
  )
299
  labelme_record.shapes.append(labelme_shape)
300
  return labelme_record
301
+
302
 
303
  @dataclass
304
  class YoloObject:
 
307
  y_center: float
308
  width: float
309
  height: float
310
+
311
+
312
  @dataclass
313
  class YoloRecord:
314
  filename: Optional[str] = None
315
  width: Optional[int] = None
316
  height: Optional[int] = None
317
  objects: List[YoloObject] = field(default_factory=list)
318
+
319
+
320
  class YoloHandler:
321
  @staticmethod
322
  def load(filename, **kwargs) -> YoloRecord:
 
343
  def save(filename, yolo_record: YoloRecord):
344
  records = []
345
  for object in yolo_record.objects:
346
+ records.append(
347
+ f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
348
  if not filename.endswith('.txt'):
349
  filename = filename + '.txt'
350
  khandy.save_list(filename, records)
 
357
  height=yolo_record.height
358
  )
359
  for yolo_object in yolo_record.objects:
360
+ x_min = (yolo_object.x_center - 0.5 *
361
+ yolo_object.width) * yolo_record.width
362
+ y_min = (yolo_object.y_center - 0.5 *
363
+ yolo_object.height) * yolo_record.height
364
+ x_max = (yolo_object.x_center + 0.5 *
365
+ yolo_object.width) * yolo_record.width
366
+ y_max = (yolo_object.y_center + 0.5 *
367
+ yolo_object.height) * yolo_record.height
368
  ir_object = DetectIrObject(
369
  label=yolo_object.label,
370
  x_min=x_min,
 
374
  )
375
  ir_record.objects.append(ir_object)
376
  return ir_record
377
+
378
  @staticmethod
379
  def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
380
  yolo_record = YoloRecord(
 
383
  height=ir_record.height
384
  )
385
  for ir_object in ir_record.objects:
386
+ x_center = (ir_object.x_max + ir_object.x_min) / \
387
+ (2 * ir_record.width)
388
+ y_center = (ir_object.y_max + ir_object.y_min) / \
389
+ (2 * ir_record.height)
390
  width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
391
  height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
392
  yolo_object = YoloObject(
 
398
  )
399
  yolo_record.objects.append(yolo_object)
400
  return yolo_record
401
+
402
+
403
  @dataclass
404
  class CocoObject:
405
  label: str
 
407
  y_min: float
408
  width: float
409
  height: float
410
+
411
+
412
  @dataclass
413
  class CocoRecord:
414
  filename: str
415
  width: int
416
  height: int
417
  objects: List[CocoObject] = field(default_factory=list)
418
+
419
 
420
  class CocoHandler:
421
  @staticmethod
422
  def load(filename, **kwargs) -> List[CocoRecord]:
423
  json_data = khandy.load_json(filename)
424
+
425
  images = json_data['images']
426
  annotations = json_data['annotations']
427
  categories = json_data['categories']
428
+
429
  label_map = {}
430
  for cat_item in categories:
431
  label_map[cat_item['id']] = cat_item['name']
432
+
433
  coco_records = OrderedDict()
434
  for image_item in images:
435
  coco_records[image_item['id']] = CocoRecord(
 
437
  width=image_item['width'],
438
  height=image_item['height'],
439
  objects=[])
440
+
441
  for annotation_item in annotations:
442
  coco_object = CocoObject(
443
  label=label_map[annotation_item['category_id']],
 
445
  y_min=annotation_item['bbox'][1],
446
  width=annotation_item['bbox'][2],
447
  height=annotation_item['bbox'][3])
448
+ coco_records[annotation_item['image_id']
449
+ ].objects.append(coco_object)
450
  return list(coco_records.values())
451
+
452
  @staticmethod
453
  def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
454
  ir_record = DetectIrRecord(
 
484
  )
485
  coco_record.objects.append(coco_object)
486
  return coco_record
487
+
488
+
489
  def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
490
  if fmt == 'labelme':
491
  labelme_record = LabelmeHandler.load(filename, **kwargs)
 
498
  ir_record = PascalVocHandler.to_ir(pascal_voc_record)
499
  elif fmt == 'coco':
500
  coco_records = CocoHandler.load(filename, **kwargs)
501
+ ir_record = [CocoHandler.to_ir(coco_record)
502
+ for coco_record in coco_records]
503
  else:
504
  raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
505
  return ir_record
506
+
507
+
508
  def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
509
  os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
510
  if out_fmt == 'labelme':
 
538
 
539
 
540
  def convert_detect(record, out_fmt):
541
+ allowed_fmts = ('labelme', 'yolo', 'voc', 'coco',
542
+ 'pascal', 'pascal_voc', 'ir', 'detect_ir')
543
  if out_fmt not in allowed_fmts:
544
+ raise ValueError(
545
+ "Unsupported label format conversions for given out_fmt")
546
  if out_fmt in _get_format(record):
547
  return record
548
 
 
558
  ir_record = record
559
  else:
560
  raise TypeError('Unsupported type for record')
561
+
562
  if out_fmt in ('ir', 'detect_ir'):
563
  dst_record = ir_record
564
  elif out_fmt == 'labelme':
 
570
  elif out_fmt == 'coco':
571
  dst_record = CocoHandler.from_ir(ir_record)
572
  return dst_record
573
+
574
 
575
  def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
576
  dst_record = copy.deepcopy(record)
 
592
  json_data = khandy.load_json(filename)
593
  categories = json_data['categories']
594
  return [cat_item['name'] for cat_item in categories]
 
khandy/split_utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import numbers
2
- from collections import Sequence
3
 
4
  import numpy as np
5
 
@@ -15,7 +15,7 @@ def split_by_num(x, num_splits, strict=True):
15
  # NB: np.ndarray is not Sequence
16
  assert isinstance(x, (Sequence, np.ndarray))
17
  assert isinstance(num_splits, numbers.Integral)
18
-
19
  if strict:
20
  assert len(x) % num_splits == 0
21
  split_size = (len(x) + num_splits - 1) // num_splits
@@ -23,8 +23,8 @@ def split_by_num(x, num_splits, strict=True):
23
  for i in range(0, len(x), split_size):
24
  out_list.append(x[i: i + split_size])
25
  return out_list
26
-
27
-
28
  def split_by_size(x, sizes):
29
  """
30
  References:
@@ -34,7 +34,7 @@ def split_by_size(x, sizes):
34
  # NB: np.ndarray is not Sequence
35
  assert isinstance(x, (Sequence, np.ndarray))
36
  assert isinstance(sizes, (list, tuple))
37
-
38
  assert sum(sizes) == len(x)
39
  out_list = []
40
  start_index = 0
@@ -42,8 +42,8 @@ def split_by_size(x, sizes):
42
  out_list.append(x[start_index: start_index + size])
43
  start_index += size
44
  return out_list
45
-
46
-
47
  def split_by_slice(x, slices):
48
  """
49
  References:
@@ -52,7 +52,7 @@ def split_by_slice(x, slices):
52
  # NB: np.ndarray is not Sequence
53
  assert isinstance(x, (Sequence, np.ndarray))
54
  assert isinstance(slices, (list, tuple))
55
-
56
  out_list = []
57
  indices = [0] + list(slices) + [len(x)]
58
  for i in range(len(slices) + 1):
@@ -64,10 +64,8 @@ def split_by_ratio(x, ratios):
64
  # NB: np.ndarray is not Sequence
65
  assert isinstance(x, (Sequence, np.ndarray))
66
  assert isinstance(ratios, (list, tuple))
67
-
68
  pdf = [k / sum(ratios) for k in ratios]
69
  cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
70
  indices = [int(round(len(x) * k)) for k in cdf]
71
  return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]
72
-
73
-
 
1
  import numbers
2
+ from collections.abc import Sequence
3
 
4
  import numpy as np
5
 
 
15
  # NB: np.ndarray is not Sequence
16
  assert isinstance(x, (Sequence, np.ndarray))
17
  assert isinstance(num_splits, numbers.Integral)
18
+
19
  if strict:
20
  assert len(x) % num_splits == 0
21
  split_size = (len(x) + num_splits - 1) // num_splits
 
23
  for i in range(0, len(x), split_size):
24
  out_list.append(x[i: i + split_size])
25
  return out_list
26
+
27
+
28
  def split_by_size(x, sizes):
29
  """
30
  References:
 
34
  # NB: np.ndarray is not Sequence
35
  assert isinstance(x, (Sequence, np.ndarray))
36
  assert isinstance(sizes, (list, tuple))
37
+
38
  assert sum(sizes) == len(x)
39
  out_list = []
40
  start_index = 0
 
42
  out_list.append(x[start_index: start_index + size])
43
  start_index += size
44
  return out_list
45
+
46
+
47
  def split_by_slice(x, slices):
48
  """
49
  References:
 
52
  # NB: np.ndarray is not Sequence
53
  assert isinstance(x, (Sequence, np.ndarray))
54
  assert isinstance(slices, (list, tuple))
55
+
56
  out_list = []
57
  indices = [0] + list(slices) + [len(x)]
58
  for i in range(len(slices) + 1):
 
64
  # NB: np.ndarray is not Sequence
65
  assert isinstance(x, (Sequence, np.ndarray))
66
  assert isinstance(ratios, (list, tuple))
67
+
68
  pdf = [k / sum(ratios) for k in ratios]
69
  cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
70
  indices = [int(round(len(x) * k)) for k in cdf]
71
  return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]