# NOTE: page-header residue from the hosting site ("Spaces: Running"); not part of the module.
import os | |
import copy | |
import json | |
import dataclasses | |
from dataclasses import dataclass, field | |
from collections import OrderedDict | |
from typing import Optional, List | |
import xml.etree.ElementTree as ET | |
import khandy | |
import lxml | |
import lxml.builder | |
import numpy as np | |
# Names exported by ``from <module> import *``.
__all__ = [
    'DetectIrObject',
    'DetectIrRecord',
    'load_detect',
    'save_detect',
    'convert_detect',
    'replace_detect_label',
    'load_coco_class_names',
]
@dataclass
class DetectIrObject:
    """Intermediate representation (IR) of one labelled bounding box.

    The box is stored by its two corners, ``(x_min, y_min)`` and
    ``(x_max, y_max)``. The handlers in this module fill these with
    absolute (non-normalised) image coordinates.
    """

    label: str
    x_min: float
    y_min: float
    x_max: float
    y_max: float
@dataclass
class DetectIrRecord:
    """Intermediate representation (IR) of the annotations of one image.

    Attributes are the image file name, the image width and height, and
    the list of annotated boxes. Every record gets its own ``objects``
    list (``default_factory``), so appending to one record never affects
    another.
    """

    filename: str
    width: int
    height: int
    objects: List[DetectIrObject] = field(default_factory=list)
@dataclass
class PascalVocSource:
    """Contents of the ``<source>`` element of a Pascal VOC annotation.

    All three fields default to the empty string.
    """

    database: str = ''
    annotation: str = ''
    image: str = ''
@dataclass
class PascalVocSize:
    """Contents of the ``<size>`` element of a Pascal VOC annotation.

    Note the positional order: ``height`` comes before ``width``.
    """

    height: int
    width: int
    depth: int
@dataclass
class PascalVocBndbox:
    """Contents of the ``<bndbox>`` element: the two corners of a box."""

    xmin: float
    ymin: float
    xmax: float
    ymax: float
@dataclass
class PascalVocObject:
    """Contents of one ``<object>`` element of a Pascal VOC annotation.

    Only ``name`` is required; the remaining fields fall back to
    ``'Unspecified'`` / ``0`` / ``None``.
    """

    name: str
    pose: str = 'Unspecified'
    truncated: int = 0
    difficult: int = 0
    bndbox: Optional[PascalVocBndbox] = None
@dataclass
class PascalVocRecord:
    """In-memory form of one Pascal VOC annotation file.

    Every field has a default, so an empty record can be created with
    ``PascalVocRecord()`` and filled in afterwards.
    """

    folder: str = ''
    filename: str = ''
    path: str = ''
    # default_factory gives every record its own source object. A plain
    # ``PascalVocSource()`` default would be shared by all records, and
    # ``dataclass`` rejects such unhashable defaults on Python 3.11+.
    source: PascalVocSource = field(default_factory=PascalVocSource)
    size: Optional[PascalVocSize] = None
    segmented: int = 0
    objects: List[PascalVocObject] = field(default_factory=list)
class PascalVocHandler:
    """Load, save and convert Pascal VOC XML annotations.

    The class is a plain namespace: every method is static.
    """

    @staticmethod
    def _text(parent, tag, default):
        """Return the text of child ``tag`` of ``parent``.

        ``default`` is returned when the child is absent or has no text.
        """
        node = parent.find(tag)
        if node is None or node.text is None:
            return default
        return node.text

    @staticmethod
    def load(filename, **kwargs) -> PascalVocRecord:
        """Parse the Pascal VOC XML file ``filename``.

        ``folder``, ``filename``, ``path``, ``segmented``,
        ``source/database`` and the per-object ``pose``, ``truncated`` and
        ``difficult`` tags are optional; missing ones fall back to the
        dataclass defaults. ``size`` (with ``width``/``height``/``depth``)
        and each object's ``name`` and ``bndbox`` are required.

        Args:
            filename: Path of the XML file.
            **kwargs: Ignored; accepted so all handlers share one call
                signature.

        Returns:
            The parsed record. Box coordinates are the file values
            minus 1.

        Raises:
            AttributeError: If a required element is missing.
            ValueError: If a numeric element cannot be converted.
        """
        get_text = PascalVocHandler._text
        xml_tree = ET.parse(filename)

        source_tag = xml_tree.find('source')
        database = '' if source_tag is None else get_text(
            source_tag, 'database', '')

        size_tag = xml_tree.find('size')
        size = PascalVocSize(
            width=int(size_tag.find('width').text),
            height=int(size_tag.find('height').text),
            depth=int(size_tag.find('depth').text),
        )

        pascal_voc_record = PascalVocRecord(
            folder=get_text(xml_tree, 'folder', ''),
            filename=get_text(xml_tree, 'filename', ''),
            path=get_text(xml_tree, 'path', ''),
            source=PascalVocSource(database=database),
            size=size,
            segmented=int(get_text(xml_tree, 'segmented', 0)),
        )

        for object_tag in xml_tree.findall('object'):
            bndbox_tag = object_tag.find('bndbox')
            # Every coordinate is shifted by -1 on load.
            # NOTE(review): `save` writes coordinates back unchanged (no
            # +1), so a load/save round trip moves each box by one unit.
            # Confirm whether that asymmetry is intended.
            bndbox = PascalVocBndbox(
                xmin=float(bndbox_tag.find('xmin').text) - 1,
                ymin=float(bndbox_tag.find('ymin').text) - 1,
                xmax=float(bndbox_tag.find('xmax').text) - 1,
                ymax=float(bndbox_tag.find('ymax').text) - 1,
            )
            pascal_voc_record.objects.append(PascalVocObject(
                name=object_tag.find('name').text,
                pose=get_text(object_tag, 'pose', 'Unspecified'),
                truncated=int(get_text(object_tag, 'truncated', 0)),
                difficult=int(get_text(object_tag, 'difficult', 0)),
                bndbox=bndbox,
            ))
        return pascal_voc_record

    @staticmethod
    def save(filename, pascal_voc_record: PascalVocRecord):
        """Write ``pascal_voc_record`` to ``filename`` as Pascal VOC XML.

        ``'.xml'`` is appended to ``filename`` when it does not already
        end with it. The output is pretty-printed and UTF-8 encoded.

        Args:
            filename: Destination path.
            pascal_voc_record: Record to serialise; its ``size`` and every
                object's ``bndbox`` must be set.
        """
        maker = lxml.builder.ElementMaker()
        size = pascal_voc_record.size
        xml = maker.annotation(
            maker.folder(pascal_voc_record.folder),
            maker.filename(pascal_voc_record.filename),
            maker.path(pascal_voc_record.path),
            maker.source(
                maker.database(pascal_voc_record.source.database),
            ),
            maker.size(
                maker.width(str(size.width)),
                maker.height(str(size.height)),
                maker.depth(str(size.depth)),
            ),
            maker.segmented(str(pascal_voc_record.segmented)),
        )

        for pascal_voc_object in pascal_voc_record.objects:
            bndbox = pascal_voc_object.bndbox
            xml.append(maker.object(
                maker.name(pascal_voc_object.name),
                maker.pose(pascal_voc_object.pose),
                maker.truncated(str(pascal_voc_object.truncated)),
                maker.difficult(str(pascal_voc_object.difficult)),
                maker.bndbox(
                    maker.xmin(str(float(bndbox.xmin))),
                    maker.ymin(str(float(bndbox.ymin))),
                    maker.xmax(str(float(bndbox.xmax))),
                    maker.ymax(str(float(bndbox.ymax))),
                ),
            ))

        if not filename.endswith('.xml'):
            filename = filename + '.xml'
        with open(filename, 'wb') as f:
            f.write(lxml.etree.tostring(
                xml, pretty_print=True, encoding='utf-8'))

    @staticmethod
    def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
        """Convert a Pascal VOC record to the intermediate representation.

        File name, image size, object names and box corners are copied
        over unchanged; the other VOC fields are dropped.
        """
        ir_record = DetectIrRecord(
            filename=pascal_voc_record.filename,
            width=pascal_voc_record.size.width,
            height=pascal_voc_record.size.height,
        )
        for pascal_voc_object in pascal_voc_record.objects:
            bndbox = pascal_voc_object.bndbox
            ir_record.objects.append(DetectIrObject(
                label=pascal_voc_object.name,
                x_min=bndbox.xmin,
                y_min=bndbox.ymin,
                x_max=bndbox.xmax,
                y_max=bndbox.ymax,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
        """Build a Pascal VOC record from the intermediate representation.

        The IR carries no channel count, so ``depth`` is always set to 3.
        Fields without an IR counterpart keep their defaults.
        """
        pascal_voc_record = PascalVocRecord(
            filename=ir_record.filename,
            size=PascalVocSize(
                width=ir_record.width,
                height=ir_record.height,
                depth=3,
            ),
        )
        for ir_object in ir_record.objects:
            pascal_voc_record.objects.append(PascalVocObject(
                name=ir_object.label,
                bndbox=PascalVocBndbox(
                    xmin=ir_object.x_min,
                    ymin=ir_object.y_min,
                    xmax=ir_object.x_max,
                    ymax=ir_object.y_max,
                ),
            ))
        return pascal_voc_record
class _NumpyEncoder(json.JSONEncoder): | |
""" Special json encoder for numpy types """ | |
def default(self, obj): | |
if isinstance(obj, (np.bool_,)): | |
return bool(obj) | |
elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8, | |
np.int16, np.int32, np.int64, np.uint8, | |
np.uint16, np.uint32, np.uint64)): | |
return int(obj) | |
elif isinstance(obj, (np.float_, np.float16, np.float32, | |
np.float64)): | |
return float(obj) | |
elif isinstance(obj, (np.ndarray,)): | |
return obj.tolist() | |
return json.JSONEncoder.default(self, obj) | |
@dataclass
class LabelmeShape:
    """One entry of the ``shapes`` list of a labelme annotation.

    ``points`` may be given as any array-like; it is converted to a
    NumPy array after construction.
    """

    label: str
    points: np.ndarray
    shape_type: str
    flags: dict = field(default_factory=dict)
    group_id: Optional[int] = None

    def __post_init__(self):
        # Normalise list input (e.g. straight from JSON) to an ndarray.
        self.points = np.asarray(self.points)
@dataclass
class LabelmeRecord:
    """In-memory form of one labelme JSON annotation file.

    Field names mirror the JSON keys, so a parsed file can be passed
    directly as ``LabelmeRecord(**json_content)``.
    """

    version: str = '4.5.6'
    flags: dict = field(default_factory=dict)
    shapes: List[LabelmeShape] = field(default_factory=list)
    imagePath: Optional[str] = None
    imageData: Optional[str] = None
    imageHeight: Optional[int] = None
    imageWidth: Optional[int] = None

    def __post_init__(self):
        # Shapes arrive as plain dicts when loaded from JSON; promote those
        # to LabelmeShape and leave already-built entries alone. A new list
        # is built so the caller's input list is not modified.
        self.shapes = [
            LabelmeShape(**shape) if isinstance(shape, dict) else shape
            for shape in self.shapes
        ]
class LabelmeHandler:
    """Load, save and convert labelme JSON annotations.

    The class is a plain namespace: every method is static.
    """

    @staticmethod
    def load(filename, **kwargs) -> LabelmeRecord:
        """Read the labelme JSON file ``filename`` into a record.

        The top-level JSON keys are passed as keyword arguments to
        ``LabelmeRecord``. ``**kwargs`` is ignored; it is accepted so all
        handlers share one call signature.
        """
        json_content = khandy.load_json(filename)
        return LabelmeRecord(**json_content)

    @staticmethod
    def save(filename, labelme_record: LabelmeRecord):
        """Write ``labelme_record`` to ``filename`` as JSON.

        NumPy values (e.g. the shape points) are serialised through
        ``_NumpyEncoder``.
        """
        json_content = dataclasses.asdict(labelme_record)
        khandy.save_json(filename, json_content, cls=_NumpyEncoder)

    @staticmethod
    def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
        """Convert a labelme record to the intermediate representation.

        Only shapes whose ``shape_type`` is ``'rectangle'`` are kept. The
        two rectangle points are treated as opposite corners in either
        order, so ``x_min <= x_max`` and ``y_min <= y_max`` always hold.
        """
        ir_record = DetectIrRecord(
            filename=labelme_record.imagePath,
            width=labelme_record.imageWidth,
            height=labelme_record.imageHeight,
        )
        for labelme_shape in labelme_record.shapes:
            if labelme_shape.shape_type != 'rectangle':
                continue
            first, second = labelme_shape.points[0], labelme_shape.points[1]
            # The first point is not guaranteed to be the top-left corner,
            # so order each axis explicitly.
            ir_record.objects.append(DetectIrObject(
                label=labelme_shape.label,
                x_min=min(first[0], second[0]),
                y_min=min(first[1], second[1]),
                x_max=max(first[0], second[0]),
                y_max=max(first[1], second[1]),
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
        """Build a labelme record from the intermediate representation.

        Every IR box becomes a ``'rectangle'`` shape with the points
        ``[x_min, y_min]`` and ``[x_max, y_max]``.
        """
        labelme_record = LabelmeRecord(
            imagePath=ir_record.filename,
            imageWidth=ir_record.width,
            imageHeight=ir_record.height,
        )
        for ir_object in ir_record.objects:
            labelme_record.shapes.append(LabelmeShape(
                label=ir_object.label,
                shape_type='rectangle',
                points=[[ir_object.x_min, ir_object.y_min],
                        [ir_object.x_max, ir_object.y_max]],
            ))
        return labelme_record
@dataclass
class YoloObject:
    """One line of a YOLO label file: a label plus a centre/size box.

    The conversion to the intermediate representation in this module
    multiplies these values by the image width/height, i.e. they are
    treated as fractions of the image size.
    """

    label: str
    x_center: float
    y_center: float
    width: float
    height: float
@dataclass
class YoloRecord:
    """All YOLO boxes of one image, plus the image name and size.

    The image name and size are not read from the label file; the loader
    in this module takes them from its keyword arguments.
    """

    filename: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    objects: List[YoloObject] = field(default_factory=list)
class YoloHandler:
    """Load, save and convert YOLO text label files.

    The class is a plain namespace: every method is static.
    """

    @staticmethod
    def load(filename, **kwargs) -> YoloRecord:
        """Read the YOLO label file ``filename``.

        Each non-blank line is split on whitespace into
        ``label x_center y_center width height``.

        Args:
            filename: Path of the label file.
            **kwargs: Must contain ``image_filename``, ``width`` and
                ``height``; the label file itself does not carry them.

        Returns:
            The parsed record.

        Raises:
            AssertionError: If a required keyword argument is missing
                (kept as an assertion so the exception type callers see
                does not change).
            IndexError: If a non-blank line has fewer than five fields.
        """
        assert 'image_filename' in kwargs
        assert 'width' in kwargs and 'height' in kwargs

        yolo_record = YoloRecord(
            filename=kwargs.get('image_filename'),
            width=kwargs.get('width'),
            height=kwargs.get('height'))
        for line in khandy.load_list(filename):
            parts = line.split()
            if not parts:
                # Blank line (e.g. trailing newline): nothing to parse.
                continue
            yolo_record.objects.append(YoloObject(
                label=parts[0],
                x_center=float(parts[1]),
                y_center=float(parts[2]),
                width=float(parts[3]),
                height=float(parts[4]),
            ))
        return yolo_record

    @staticmethod
    def save(filename, yolo_record: YoloRecord):
        """Write ``yolo_record`` to ``filename``, one box per line.

        ``'.txt'`` is appended to ``filename`` when it does not already
        end with it.
        """
        records = [
            f'{yolo_object.label} {yolo_object.x_center} '
            f'{yolo_object.y_center} {yolo_object.width} '
            f'{yolo_object.height}'
            for yolo_object in yolo_record.objects
        ]
        if not filename.endswith('.txt'):
            filename = filename + '.txt'
        khandy.save_list(filename, records)

    @staticmethod
    def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
        """Convert a YOLO record to the intermediate representation.

        Centre/size values are turned into corner coordinates and scaled
        by the record's image width and height.
        """
        ir_record = DetectIrRecord(
            filename=yolo_record.filename,
            width=yolo_record.width,
            height=yolo_record.height,
        )
        for yolo_object in yolo_record.objects:
            x_min = (yolo_object.x_center - 0.5 *
                     yolo_object.width) * yolo_record.width
            y_min = (yolo_object.y_center - 0.5 *
                     yolo_object.height) * yolo_record.height
            x_max = (yolo_object.x_center + 0.5 *
                     yolo_object.width) * yolo_record.width
            y_max = (yolo_object.y_center + 0.5 *
                     yolo_object.height) * yolo_record.height
            ir_record.objects.append(DetectIrObject(
                label=yolo_object.label,
                x_min=x_min,
                y_min=y_min,
                x_max=x_max,
                y_max=y_max,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
        """Build a YOLO record from the intermediate representation.

        Corner coordinates are turned into centre/size values and divided
        by the image width and height. ``abs`` keeps the box size
        non-negative even when the corners are given in reverse order.

        Raises:
            ZeroDivisionError: If the record's width or height is 0 and
                it contains at least one object.
        """
        yolo_record = YoloRecord(
            filename=ir_record.filename,
            width=ir_record.width,
            height=ir_record.height,
        )
        for ir_object in ir_record.objects:
            x_center = (ir_object.x_max + ir_object.x_min) / \
                (2 * ir_record.width)
            y_center = (ir_object.y_max + ir_object.y_min) / \
                (2 * ir_record.height)
            width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
            height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
            yolo_record.objects.append(YoloObject(
                label=ir_object.label,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
            ))
        return yolo_record
@dataclass
class CocoObject:
    """One COCO annotation: a label plus a corner/size box.

    The box is stored as its ``(x_min, y_min)`` corner together with its
    width and height.
    """

    label: str
    x_min: float
    y_min: float
    width: float
    height: float
@dataclass
class CocoRecord:
    """All COCO annotations of one image, plus the image name and size.

    Every record gets its own ``objects`` list (``default_factory``).
    """

    filename: str
    width: int
    height: int
    objects: List[CocoObject] = field(default_factory=list)
class CocoHandler:
    """Load and convert COCO JSON annotations.

    The class is a plain namespace: every method is static. There is no
    ``save`` counterpart in this class.
    """

    @staticmethod
    def load(filename, **kwargs) -> List[CocoRecord]:
        """Read the COCO JSON file ``filename``.

        The file must contain ``images``, ``annotations`` and
        ``categories`` lists. Category ids are resolved to category names
        and every annotation is attached to the image it refers to.

        Args:
            filename: Path of the JSON file.
            **kwargs: Ignored; accepted so all handlers share one call
                signature.

        Returns:
            One record per image, in the order the images are listed.
            Images without annotations yield records with no objects.

        Raises:
            KeyError: If an expected key is missing, or an annotation
                refers to an unknown image or category id.
        """
        json_data = khandy.load_json(filename)
        images = json_data['images']
        annotations = json_data['annotations']
        categories = json_data['categories']

        label_map = {
            cat_item['id']: cat_item['name'] for cat_item in categories
        }

        # Keyed by image id so annotations can be routed to their image.
        coco_records = OrderedDict()
        for image_item in images:
            coco_records[image_item['id']] = CocoRecord(
                filename=image_item['file_name'],
                width=image_item['width'],
                height=image_item['height'],
                objects=[])

        for annotation_item in annotations:
            bbox = annotation_item['bbox']
            coco_object = CocoObject(
                label=label_map[annotation_item['category_id']],
                x_min=bbox[0],
                y_min=bbox[1],
                width=bbox[2],
                height=bbox[3])
            coco_records[annotation_item['image_id']].objects.append(
                coco_object)
        return list(coco_records.values())

    @staticmethod
    def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
        """Convert a COCO record to the intermediate representation.

        The corner/size box becomes a two-corner box:
        ``x_max = x_min + width`` and ``y_max = y_min + height``.
        """
        ir_record = DetectIrRecord(
            filename=coco_record.filename,
            width=coco_record.width,
            height=coco_record.height,
        )
        for coco_object in coco_record.objects:
            ir_record.objects.append(DetectIrObject(
                label=coco_object.label,
                x_min=coco_object.x_min,
                y_min=coco_object.y_min,
                x_max=coco_object.x_min + coco_object.width,
                y_max=coco_object.y_min + coco_object.height,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
        """Build a COCO record from the intermediate representation.

        The two-corner box becomes a corner/size box:
        ``width = x_max - x_min`` and ``height = y_max - y_min``.
        """
        coco_record = CocoRecord(
            filename=ir_record.filename,
            width=ir_record.width,
            height=ir_record.height,
        )
        for ir_object in ir_record.objects:
            coco_record.objects.append(CocoObject(
                label=ir_object.label,
                x_min=ir_object.x_min,
                y_min=ir_object.y_min,
                width=ir_object.x_max - ir_object.x_min,
                height=ir_object.y_max - ir_object.y_min,
            ))
        return coco_record
def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
    """Load a detection label file and return it in IR form.

    Args:
        filename: Path of the label file.
        fmt: Source format: ``'labelme'``, ``'yolo'``, ``'coco'`` or one
            of ``'voc'`` / ``'pascal'`` / ``'pascal_voc'``.
        **kwargs: Forwarded to the format handler's ``load``.

    Returns:
        A single IR record, or a list of IR records (one per image) when
        ``fmt`` is ``'coco'``.

    Raises:
        ValueError: If ``fmt`` is not one of the supported names.
    """
    if fmt == 'labelme':
        return LabelmeHandler.to_ir(LabelmeHandler.load(filename, **kwargs))
    if fmt == 'yolo':
        return YoloHandler.to_ir(YoloHandler.load(filename, **kwargs))
    if fmt in ('voc', 'pascal', 'pascal_voc'):
        return PascalVocHandler.to_ir(
            PascalVocHandler.load(filename, **kwargs))
    if fmt == 'coco':
        loaded = CocoHandler.load(filename, **kwargs)
        return [CocoHandler.to_ir(item) for item in loaded]
    raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
    """Save an IR record to ``filename`` in the requested format.

    The directory of ``filename`` is created first (before ``out_fmt``
    is examined) if it does not exist.

    Args:
        filename: Destination path.
        ir_record: Record to write.
        out_fmt: Target format: ``'labelme'``, ``'yolo'`` or one of
            ``'voc'`` / ``'pascal'`` / ``'pascal_voc'``.

    Raises:
        ValueError: If ``out_fmt`` is ``'coco'`` (saving is not
            implemented) or not a known format name.
    """
    target_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(target_dir, exist_ok=True)

    if out_fmt == 'labelme':
        LabelmeHandler.save(filename, LabelmeHandler.from_ir(ir_record))
        return
    if out_fmt == 'yolo':
        YoloHandler.save(filename, YoloHandler.from_ir(ir_record))
        return
    if out_fmt in ('voc', 'pascal', 'pascal_voc'):
        PascalVocHandler.save(filename, PascalVocHandler.from_ir(ir_record))
        return
    if out_fmt == 'coco':
        raise ValueError("Unsupported for `coco` now!")
    raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
def _get_format(record):
    """Return the format names that describe ``record``'s type.

    The result is a tuple of aliases (e.g. ``('voc', 'pascal',
    'pascal_voc')``), or an empty tuple for an unrecognised type.
    """
    # Checked in order; the first matching type wins.
    known_types = (
        (LabelmeRecord, ('labelme',)),
        (YoloRecord, ('yolo',)),
        (PascalVocRecord, ('voc', 'pascal', 'pascal_voc')),
        (CocoRecord, ('coco',)),
        (DetectIrRecord, ('ir', 'detect_ir')),
    )
    for record_type, aliases in known_types:
        if isinstance(record, record_type):
            return aliases
    return ()
def convert_detect(record, out_fmt):
    """Convert ``record`` to the record type named by ``out_fmt``.

    Args:
        record: A labelme, YOLO, Pascal VOC, COCO or IR record.
        out_fmt: Target format: ``'labelme'``, ``'yolo'``, ``'coco'``,
            ``'voc'`` / ``'pascal'`` / ``'pascal_voc'`` or
            ``'ir'`` / ``'detect_ir'``.

    Returns:
        ``record`` itself when it already has the requested format,
        otherwise a new record obtained by going through the IR.

    Raises:
        ValueError: If ``out_fmt`` is not a known format name.
        TypeError: If ``record`` is not one of the supported types.
    """
    voc_names = ('voc', 'pascal', 'pascal_voc')
    ir_names = ('ir', 'detect_ir')
    if out_fmt not in ('labelme', 'yolo', 'coco') + voc_names + ir_names:
        raise ValueError(
            "Unsupported label format conversions for given out_fmt")

    # Nothing to do when the record already is of the requested kind.
    if out_fmt in _get_format(record):
        return record

    # Step 1: bring the input into the intermediate representation.
    if isinstance(record, LabelmeRecord):
        ir_record = LabelmeHandler.to_ir(record)
    elif isinstance(record, YoloRecord):
        ir_record = YoloHandler.to_ir(record)
    elif isinstance(record, PascalVocRecord):
        ir_record = PascalVocHandler.to_ir(record)
    elif isinstance(record, CocoRecord):
        ir_record = CocoHandler.to_ir(record)
    elif isinstance(record, DetectIrRecord):
        ir_record = record
    else:
        raise TypeError('Unsupported type for record')

    # Step 2: build the requested output type from the IR.
    if out_fmt in ir_names:
        return ir_record
    if out_fmt == 'labelme':
        return LabelmeHandler.from_ir(ir_record)
    if out_fmt == 'yolo':
        return YoloHandler.from_ir(ir_record)
    if out_fmt in voc_names:
        return PascalVocHandler.from_ir(ir_record)
    # Only 'coco' is left after the validation at the top.
    return CocoHandler.from_ir(ir_record)
def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
    """Return a copy of ``record`` with its labels renamed via ``label_map``.

    Args:
        record: Record to relabel; it is deep-copied and left untouched.
        label_map: Mapping from old label to new label.
        ignore: When true, objects whose label is not in ``label_map``
            are dropped from the result; when false they are kept with
            their original label.

    Returns:
        The relabelled copy.
    """
    relabelled = copy.deepcopy(record)
    kept = []
    for item in relabelled.objects:
        is_mapped = item.label in label_map
        if is_mapped:
            item.label = label_map[item.label]
        if is_mapped or not ignore:
            kept.append(item)
    relabelled.objects = kept
    return relabelled
def load_coco_class_names(filename):
    """Return the category names listed in a COCO JSON file.

    The names are taken from the ``name`` key of every entry of the
    file's ``categories`` list, in file order.
    """
    coco_content = khandy.load_json(filename)
    return [category['name'] for category in coco_content['categories']]