insecta / khandy /label /detect.py
admin
sync
67a9b5d
raw
history blame
20 kB
import os
import copy
import json
import dataclasses
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Optional, List
import xml.etree.ElementTree as ET
import khandy
import lxml
import lxml.builder
import numpy as np
__all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
'save_detect', 'convert_detect', 'replace_detect_label',
'load_coco_class_names']
@dataclass
class DetectIrObject:
"""Intermediate Representation Format of Object
"""
label: str
x_min: float
y_min: float
x_max: float
y_max: float
@dataclass
class DetectIrRecord:
"""Intermediate Representation Format of Record
"""
filename: str
width: int
height: int
objects: List[DetectIrObject] = field(default_factory=list)
@dataclass
class PascalVocSource:
database: str = ''
annotation: str = ''
image: str = ''
@dataclass
class PascalVocSize:
height: int
width: int
depth: int
@dataclass
class PascalVocBndbox:
xmin: float
ymin: float
xmax: float
ymax: float
@dataclass
class PascalVocObject:
name: str
pose: str = 'Unspecified'
truncated: int = 0
difficult: int = 0
bndbox: Optional[PascalVocBndbox] = None
@dataclass
class PascalVocRecord:
folder: str = ''
filename: str = ''
path: str = ''
source: PascalVocSource = PascalVocSource()
size: Optional[PascalVocSize] = None
segmented: int = 0
objects: List[PascalVocObject] = field(default_factory=list)
class PascalVocHandler:
@staticmethod
def load(filename, **kwargs) -> PascalVocRecord:
pascal_voc_record = PascalVocRecord()
xml_tree = ET.parse(filename)
pascal_voc_record.folder = xml_tree.find('folder').text
pascal_voc_record.filename = xml_tree.find('filename').text
pascal_voc_record.path = xml_tree.find('path').text
pascal_voc_record.segmented = xml_tree.find('segmented').text
source_tag = xml_tree.find('source')
pascal_voc_record.source = PascalVocSource(
database=source_tag.find('database').text,
# annotation=source_tag.find('annotation').text,
# image=source_tag.find('image').text
)
size_tag = xml_tree.find('size')
pascal_voc_record.size = PascalVocSize(
width=int(size_tag.find('width').text),
height=int(size_tag.find('height').text),
depth=int(size_tag.find('depth').text)
)
object_tags = xml_tree.findall('object')
for index, object_tag in enumerate(object_tags):
bndbox_tag = object_tag.find('bndbox')
bndbox = PascalVocBndbox(
xmin=float(bndbox_tag.find('xmin').text) - 1,
ymin=float(bndbox_tag.find('ymin').text) - 1,
xmax=float(bndbox_tag.find('xmax').text) - 1,
ymax=float(bndbox_tag.find('ymax').text) - 1
)
pascal_voc_object = PascalVocObject(
name=object_tag.find('name').text,
pose=object_tag.find('pose').text,
truncated=object_tag.find('truncated').text,
difficult=object_tag.find('difficult').text,
bndbox=bndbox
)
pascal_voc_record.objects.append(pascal_voc_object)
return pascal_voc_record
@staticmethod
def save(filename, pascal_voc_record: PascalVocRecord):
maker = lxml.builder.ElementMaker()
xml = maker.annotation(
maker.folder(pascal_voc_record.folder),
maker.filename(pascal_voc_record.filename),
maker.path(pascal_voc_record.path),
maker.source(
maker.database(pascal_voc_record.source.database),
),
maker.size(
maker.width(str(pascal_voc_record.size.width)),
maker.height(str(pascal_voc_record.size.height)),
maker.depth(str(pascal_voc_record.size.depth)),
),
maker.segmented(str(pascal_voc_record.segmented)),
)
for pascal_voc_object in pascal_voc_record.objects:
object_tag = maker.object(
maker.name(pascal_voc_object.name),
maker.pose(pascal_voc_object.pose),
maker.truncated(str(pascal_voc_object.truncated)),
maker.difficult(str(pascal_voc_object.difficult)),
maker.bndbox(
maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
),
)
xml.append(object_tag)
if not filename.endswith('.xml'):
filename = filename + '.xml'
with open(filename, 'wb') as f:
f.write(lxml.etree.tostring(
xml, pretty_print=True, encoding='utf-8'))
@staticmethod
def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
ir_record = DetectIrRecord(
filename=pascal_voc_record.filename,
width=pascal_voc_record.size.width,
height=pascal_voc_record.size.height
)
for pascal_voc_object in pascal_voc_record.objects:
ir_object = DetectIrObject(
label=pascal_voc_object.name,
x_min=pascal_voc_object.bndbox.xmin,
y_min=pascal_voc_object.bndbox.ymin,
x_max=pascal_voc_object.bndbox.xmax,
y_max=pascal_voc_object.bndbox.ymax
)
ir_record.objects.append(ir_object)
return ir_record
@staticmethod
def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
pascal_voc_record = PascalVocRecord(
filename=ir_record.filename,
size=PascalVocSize(
width=ir_record.width,
height=ir_record.height,
depth=3
)
)
for ir_object in ir_record.objects:
pascal_voc_object = PascalVocObject(
name=ir_object.label,
bndbox=PascalVocBndbox(
xmin=ir_object.x_min,
ymin=ir_object.y_min,
xmax=ir_object.x_max,
ymax=ir_object.y_max,
)
)
pascal_voc_record.objects.append(pascal_voc_object)
return pascal_voc_record
class _NumpyEncoder(json.JSONEncoder):
""" Special json encoder for numpy types """
def default(self, obj):
if isinstance(obj, (np.bool_,)):
return bool(obj)
elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32,
np.float64)):
return float(obj)
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
@dataclass
class LabelmeShape:
label: str
points: np.ndarray
shape_type: str
flags: dict = field(default_factory=dict)
group_id: Optional[int] = None
def __post_init__(self):
self.points = np.asarray(self.points)
@dataclass
class LabelmeRecord:
version: str = '4.5.6'
flags: dict = field(default_factory=dict)
shapes: List[LabelmeShape] = field(default_factory=list)
imagePath: Optional[str] = None
imageData: Optional[str] = None
imageHeight: Optional[int] = None
imageWidth: Optional[int] = None
def __post_init__(self):
for k, shape in enumerate(self.shapes):
self.shapes[k] = LabelmeShape(**shape)
class LabelmeHandler:
@staticmethod
def load(filename, **kwargs) -> LabelmeRecord:
json_content = khandy.load_json(filename)
return LabelmeRecord(**json_content)
@staticmethod
def save(filename, labelme_record: LabelmeRecord):
json_content = dataclasses.asdict(labelme_record)
khandy.save_json(filename, json_content, cls=_NumpyEncoder)
@staticmethod
def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
ir_record = DetectIrRecord(
filename=labelme_record.imagePath,
width=labelme_record.imageWidth,
height=labelme_record.imageHeight
)
for labelme_shape in labelme_record.shapes:
if labelme_shape.shape_type != 'rectangle':
continue
ir_object = DetectIrObject(
label=labelme_shape.label,
x_min=labelme_shape.points[0][0],
y_min=labelme_shape.points[0][1],
x_max=labelme_shape.points[1][0],
y_max=labelme_shape.points[1][1],
)
ir_record.objects.append(ir_object)
return ir_record
@staticmethod
def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
labelme_record = LabelmeRecord(
imagePath=ir_record.filename,
imageWidth=ir_record.width,
imageHeight=ir_record.height
)
for ir_object in ir_record.objects:
labelme_shape = LabelmeShape(
label=ir_object.label,
shape_type='rectangle',
points=[[ir_object.x_min, ir_object.y_min],
[ir_object.x_max, ir_object.y_max]]
)
labelme_record.shapes.append(labelme_shape)
return labelme_record
@dataclass
class YoloObject:
label: str
x_center: float
y_center: float
width: float
height: float
@dataclass
class YoloRecord:
filename: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
objects: List[YoloObject] = field(default_factory=list)
class YoloHandler:
@staticmethod
def load(filename, **kwargs) -> YoloRecord:
assert 'image_filename' in kwargs
assert 'width' in kwargs and 'height' in kwargs
records = khandy.load_list(filename)
yolo_record = YoloRecord(
filename=kwargs.get('image_filename'),
width=kwargs.get('width'),
height=kwargs.get('height'))
for record in records:
record_parts = record.split()
yolo_record.objects.append(YoloObject(
label=record_parts[0],
x_center=float(record_parts[1]),
y_center=float(record_parts[2]),
width=float(record_parts[3]),
height=float(record_parts[4]),
))
return yolo_record
@staticmethod
def save(filename, yolo_record: YoloRecord):
records = []
for object in yolo_record.objects:
records.append(
f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
if not filename.endswith('.txt'):
filename = filename + '.txt'
khandy.save_list(filename, records)
@staticmethod
def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
ir_record = DetectIrRecord(
filename=yolo_record.filename,
width=yolo_record.width,
height=yolo_record.height
)
for yolo_object in yolo_record.objects:
x_min = (yolo_object.x_center - 0.5 *
yolo_object.width) * yolo_record.width
y_min = (yolo_object.y_center - 0.5 *
yolo_object.height) * yolo_record.height
x_max = (yolo_object.x_center + 0.5 *
yolo_object.width) * yolo_record.width
y_max = (yolo_object.y_center + 0.5 *
yolo_object.height) * yolo_record.height
ir_object = DetectIrObject(
label=yolo_object.label,
x_min=x_min,
y_min=y_min,
x_max=x_max,
y_max=y_max
)
ir_record.objects.append(ir_object)
return ir_record
@staticmethod
def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
yolo_record = YoloRecord(
filename=ir_record.filename,
width=ir_record.width,
height=ir_record.height
)
for ir_object in ir_record.objects:
x_center = (ir_object.x_max + ir_object.x_min) / \
(2 * ir_record.width)
y_center = (ir_object.y_max + ir_object.y_min) / \
(2 * ir_record.height)
width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
yolo_object = YoloObject(
label=ir_object.label,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
)
yolo_record.objects.append(yolo_object)
return yolo_record
@dataclass
class CocoObject:
label: str
x_min: float
y_min: float
width: float
height: float
@dataclass
class CocoRecord:
filename: str
width: int
height: int
objects: List[CocoObject] = field(default_factory=list)
class CocoHandler:
@staticmethod
def load(filename, **kwargs) -> List[CocoRecord]:
json_data = khandy.load_json(filename)
images = json_data['images']
annotations = json_data['annotations']
categories = json_data['categories']
label_map = {}
for cat_item in categories:
label_map[cat_item['id']] = cat_item['name']
coco_records = OrderedDict()
for image_item in images:
coco_records[image_item['id']] = CocoRecord(
filename=image_item['file_name'],
width=image_item['width'],
height=image_item['height'],
objects=[])
for annotation_item in annotations:
coco_object = CocoObject(
label=label_map[annotation_item['category_id']],
x_min=annotation_item['bbox'][0],
y_min=annotation_item['bbox'][1],
width=annotation_item['bbox'][2],
height=annotation_item['bbox'][3])
coco_records[annotation_item['image_id']
].objects.append(coco_object)
return list(coco_records.values())
@staticmethod
def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
ir_record = DetectIrRecord(
filename=coco_record.filename,
width=coco_record.width,
height=coco_record.height,
)
for coco_object in coco_record.objects:
ir_object = DetectIrObject(
label=coco_object.label,
x_min=coco_object.x_min,
y_min=coco_object.y_min,
x_max=coco_object.x_min + coco_object.width,
y_max=coco_object.y_min + coco_object.height
)
ir_record.objects.append(ir_object)
return ir_record
@staticmethod
def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
coco_record = CocoRecord(
filename=ir_record.filename,
width=ir_record.width,
height=ir_record.height
)
for ir_object in ir_record.objects:
coco_object = CocoObject(
label=ir_object.label,
x_min=ir_object.x_min,
y_min=ir_object.y_min,
width=ir_object.x_max - ir_object.x_min,
height=ir_object.y_max - ir_object.y_min
)
coco_record.objects.append(coco_object)
return coco_record
def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
if fmt == 'labelme':
labelme_record = LabelmeHandler.load(filename, **kwargs)
ir_record = LabelmeHandler.to_ir(labelme_record)
elif fmt == 'yolo':
yolo_record = YoloHandler.load(filename, **kwargs)
ir_record = YoloHandler.to_ir(yolo_record)
elif fmt in ('voc', 'pascal', 'pascal_voc'):
pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
ir_record = PascalVocHandler.to_ir(pascal_voc_record)
elif fmt == 'coco':
coco_records = CocoHandler.load(filename, **kwargs)
ir_record = [CocoHandler.to_ir(coco_record)
for coco_record in coco_records]
else:
raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
return ir_record
def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
if out_fmt == 'labelme':
labelme_record = LabelmeHandler.from_ir(ir_record)
LabelmeHandler.save(filename, labelme_record)
elif out_fmt == 'yolo':
yolo_record = YoloHandler.from_ir(ir_record)
YoloHandler.save(filename, yolo_record)
elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
pascal_voc_record = PascalVocHandler.from_ir(ir_record)
PascalVocHandler.save(filename, pascal_voc_record)
elif out_fmt == 'coco':
raise ValueError("Unsupported for `coco` now!")
else:
raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
def _get_format(record):
if isinstance(record, LabelmeRecord):
return ('labelme',)
elif isinstance(record, YoloRecord):
return ('yolo',)
elif isinstance(record, PascalVocRecord):
return ('voc', 'pascal', 'pascal_voc')
elif isinstance(record, CocoRecord):
return ('coco',)
elif isinstance(record, DetectIrRecord):
return ('ir', 'detect_ir')
else:
return ()
def convert_detect(record, out_fmt):
allowed_fmts = ('labelme', 'yolo', 'voc', 'coco',
'pascal', 'pascal_voc', 'ir', 'detect_ir')
if out_fmt not in allowed_fmts:
raise ValueError(
"Unsupported label format conversions for given out_fmt")
if out_fmt in _get_format(record):
return record
if isinstance(record, LabelmeRecord):
ir_record = LabelmeHandler.to_ir(record)
elif isinstance(record, YoloRecord):
ir_record = YoloHandler.to_ir(record)
elif isinstance(record, PascalVocRecord):
ir_record = PascalVocHandler.to_ir(record)
elif isinstance(record, CocoRecord):
ir_record = CocoHandler.to_ir(record)
elif isinstance(record, DetectIrRecord):
ir_record = record
else:
raise TypeError('Unsupported type for record')
if out_fmt in ('ir', 'detect_ir'):
dst_record = ir_record
elif out_fmt == 'labelme':
dst_record = LabelmeHandler.from_ir(ir_record)
elif out_fmt == 'yolo':
dst_record = YoloHandler.from_ir(ir_record)
elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
dst_record = PascalVocHandler.from_ir(ir_record)
elif out_fmt == 'coco':
dst_record = CocoHandler.from_ir(ir_record)
return dst_record
def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
dst_record = copy.deepcopy(record)
dst_objects = []
for ir_object in dst_record.objects:
if not ignore:
if ir_object.label in label_map:
ir_object.label = label_map[ir_object.label]
dst_objects.append(ir_object)
else:
if ir_object.label in label_map:
ir_object.label = label_map[ir_object.label]
dst_objects.append(ir_object)
dst_record.objects = dst_objects
return dst_record
def load_coco_class_names(filename):
json_data = khandy.load_json(filename)
categories = json_data['categories']
return [cat_item['name'] for cat_item in categories]