zdou0830's picture
desco
749745d
raw
history blame
No virus
15.4 kB
import os
import os.path as op
import json
# import logging
import base64
import yaml
import errno
import io
import math
from PIL import Image, ImageDraw
from maskrcnn_benchmark.structures.bounding_box import BoxList
from .box_label_loader import LabelLoader
def load_linelist_file(linelist_file):
    """Read a file of integer line indexes, one per line.

    Args:
        linelist_file: Path to the index file, or ``None``.

    Returns:
        ``None`` when ``linelist_file`` is ``None``; otherwise the list of
        integers parsed from each line (surrounding whitespace ignored).
    """
    if linelist_file is None:
        return None
    with open(linelist_file, "r") as handle:
        return [int(row.strip()) for row in handle]
def img_from_base64(imagestring):
    """Decode a base64-encoded image string into an RGB ``PIL.Image``.

    Returns ``None`` when decoding raises ``ValueError`` (this includes
    ``binascii.Error`` from malformed base64); any other error propagates.
    """
    try:
        decoded = Image.open(io.BytesIO(base64.b64decode(imagestring)))
        rgb = decoded.convert("RGB")
    except ValueError:
        return None
    return rgb
def load_from_yaml_file(yaml_file):
    """Parse a YAML file and return the resulting Python object.

    Uses the libyaml-backed ``yaml.CLoader`` when PyYAML was built with
    libyaml. Otherwise it falls back to the pure-Python ``yaml.Loader``
    (``yaml.CLoader`` is not defined in that case, and referencing it would
    raise AttributeError).

    NOTE(review): both loaders are PyYAML's full (non-safe) loaders, which
    can construct arbitrary Python objects from tags. Only use this on
    trusted dataset configs. Consider ``CSafeLoader``/``SafeLoader`` if
    configs may come from untrusted sources.
    """
    loader = getattr(yaml, "CLoader", yaml.Loader)
    with open(yaml_file, "r") as fp:
        return yaml.load(fp, Loader=loader)
def find_file_path_in_yaml(fname, root):
    """Resolve a file name referenced from a dataset YAML config.

    Candidates are tried in order:
      1. ``fname`` as given (absolute, or relative to the CWD);
      2. ``fname`` joined to ``root``;
      3. ``fname`` joined to each of up to three ancestors of ``root``.

    Args:
        fname: File name from the YAML config, or ``None``.
        root: Directory used to resolve relative names.

    Returns:
        The first candidate that is an existing file, or ``None`` when
        ``fname`` is ``None``.

    Raises:
        FileNotFoundError: if ``fname`` is not ``None`` and no candidate
            exists. The reported filename is ``fname`` joined to the
            ``root`` that was passed in.
    """
    if fname is None:
        return None
    candidates = [fname, op.join(root, fname)]
    # Be a bit more flexible: also look in up to three parent directories of root.
    parent = root
    for _ in range(3):
        parent = op.dirname(parent)
        candidates.append(op.join(parent, fname))
    found_file = next((path for path in candidates if op.isfile(path)), None)
    if found_file is None:
        # Report the location relative to the caller's root, not the last
        # ancestor probed, so the message points where the file was expected.
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
        )
    print('found file: {}'.format(found_file))
    return found_file
def create_lineidx(filein, idxout):
    """Write the byte offset at which each line of ``filein`` starts to ``idxout``.

    The output has one decimal offset per line. It is written to
    ``idxout + ".tmp"`` first and then renamed onto ``idxout``.

    The input is read in binary mode, so offsets are exact byte positions and
    the TSV payload is never decoded. This avoids failures on bytes invalid
    in the locale encoding and cannot loop forever if the file changes size
    while it is being indexed.
    """
    idxout_tmp = idxout + ".tmp"
    with open(filein, "rb") as tsvin, open(idxout_tmp, "w") as tsvout:
        fpos = 0
        for line in tsvin:
            tsvout.write(str(fpos) + "\n")
            fpos += len(line)
    # os.replace (unlike os.rename) also overwrites an existing index on Windows.
    os.replace(idxout_tmp, idxout)
def read_to_character(fp, c):
    """Read from ``fp`` until the first occurrence of ``c``; return the text before it.

    Data is read in 32-character chunks. Afterwards the file position is
    somewhere past ``c``, so callers must ``seek`` before reading again.

    Raises:
        EOFError: if end of file is reached before ``c`` is found. The error
            is raised explicitly rather than via ``assert``: an assert is
            stripped under ``python -O``, which would turn a missing delimiter
            into an infinite loop.
    """
    result = []
    while True:
        s = fp.read(32)
        if not s:
            raise EOFError("reached end of file before finding {!r}".format(c))
        pos = s.find(c)
        if pos != -1:
            result.append(s[:pos])
            return "".join(result)
        result.append(s)
class TSVFile(object):
    """Random-access reader for a TSV file with a companion ``.lineidx`` index.

    The ``.lineidx`` file has the TSV's path with its extension replaced.
    It holds one integer offset per row, so row ``i`` is read with a single
    ``seek`` + ``readline``. The TSV handle and the index are opened/loaded
    lazily on first use.
    """

    def __init__(self, tsv_file, generate_lineidx=False):
        """Create a reader for ``tsv_file``.

        Args:
            tsv_file: Path to the TSV file.
            generate_lineidx: If True and the ``.lineidx`` file does not
                exist, build it now with ``create_lineidx``.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + ".lineidx"
        self._fp = None
        self._lineidx = None
        # pid of the process that opened self._fp. If the current pid
        # differs (e.g. in a forked worker), the file is re-opened.
        # Presumably this is because an inherited handle shares its file
        # offset with the parent process.
        self.pid = None
        # generate lineidx if not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            create_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        # getattr: __init__ may have raised before _fp was assigned.
        fp = getattr(self, "_fp", None)
        if fp:
            fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Return the number of rows listed in the ``.lineidx`` file."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row ``idx`` as a list of whitespace-stripped column strings."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        self._fp.seek(self._lineidx[idx])
        return [s.strip() for s in self._fp.readline().split("\t")]

    def seek_first_column(self, idx):
        """Return the text of row ``idx`` up to its first tab character."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        self._fp.seek(self._lineidx[idx])
        return read_to_character(self._fp, "\t")

    def get_key(self, idx):
        """Return the first column of row ``idx``."""
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        """Load the row offsets from ``self.lineidx`` on first use."""
        if self._lineidx is None:
            with open(self.lineidx, "r") as fp:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing on int("").
                self._lineidx = [int(line) for line in fp if line.strip()]

    def _ensure_tsv_opened(self):
        """Open the TSV on first use, and re-open it in a new process."""
        if self._fp is None:
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()
        elif self.pid != os.getpid():
            # Close this process's copy of the old handle before replacing it,
            # so the descriptor is not leaked.
            self._fp.close()
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()
class CompositeTSVFile:
    """Concatenated view over several TSV files, driven by a sequence file.

    Each line of ``seq_file`` holds ``<source index>\\t<row index>``. Row
    ``i`` of this view is row ``<row index>`` of ``file_list[<source index>]``.
    """

    def __init__(self, file_list, seq_file, root="."):
        """Set up the view and load everything eagerly.

        Args:
            file_list: Either a path to a text file listing TSV names (one
                per line) or a list of TSV names.
            seq_file: Path to the tab-separated (source, row) sequence file.
            root: Directory that the TSV names are relative to.
        """
        if not isinstance(file_list, str):
            assert isinstance(file_list, list)
            self.file_list = file_list
        else:
            self.file_list = load_list_file(file_list)
        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        """Return ``"<source file name>_<row key>"`` for row ``index``."""
        source, row = self.seq[index]
        row_key = self.tsvs[source].get_key(row)
        return "_".join([self.file_list[source], row_key])

    def num_rows(self):
        """Return the number of rows in the composite view."""
        return len(self.seq)

    def __getitem__(self, index):
        source, row = self.seq[index]
        return self.tsvs[source].seek(row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        """Read the sequence file and open one ``TSVFile`` per source.

        This does nothing once initialization has succeeded. It is always
        called from ``__init__``; the original note says it has to be when a
        cache policy is enabled, so it is done eagerly for simplicity.
        """
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, "r") as handle:
            for raw in handle:
                fields = raw.strip().split("\t")
                self.seq.append([int(fields[0]), int(fields[1])])
        self.tsvs = [TSVFile(op.join(self.root, name)) for name in self.file_list]
        self.initialized = True
def load_list_file(fname):
    """Return the lines of ``fname`` with surrounding whitespace stripped.

    If the final entry is empty after stripping, that one entry is dropped.
    Empty entries elsewhere are kept.
    """
    with open(fname, "r") as handle:
        entries = [row.strip() for row in handle]
    if entries and entries[-1] == "":
        entries.pop()
    return entries
class TSVDataset(object):
    """Dataset backed by an image TSV plus optional label / hw / linelist files.

    Subclasses override ``get_target_from_annotations`` and
    ``apply_transforms`` to decode labels and transform samples.
    """

    def __init__(self, img_file, label_file=None, hw_file=None, linelist_file=None, imageid2idx_file=None):
        """Constructor.
        Args:
            img_file: Image file with image key and base64 encoded image str.
            label_file: An optional label file with image key and label information.
                A label_file is required for training and optional for testing.
            hw_file: An optional file with image key and image height/width info.
            linelist_file: An optional file with a list of line indexes to load samples.
                It is useful to select a subset of samples or duplicate samples.
            imageid2idx_file: An optional JSON file mapping the ``"img_id"``
                field of a label row to a row index in ``img_file``. When
                given, samples are enumerated over ``label_file`` and each
                label row is matched to its image through this mapping.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = TSVFile(img_file)
        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
        self.line_list = load_linelist_file(linelist_file)
        self.imageid2idx = None
        if imageid2idx_file is not None:
            # Context manager so the handle is closed instead of leaked.
            with open(imageid2idx_file, "r") as fp:
                self.imageid2idx = json.load(fp)
        self.transforms = None

    def __len__(self):
        if self.line_list is not None:
            return len(self.line_list)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            return self.label_tsv.num_rows()
        return self.img_tsv.num_rows()

    def __getitem__(self, idx):
        """Return ``(img, target, idx, scale)`` for sample ``idx``.

        ``scale`` is 1.0 when no transforms are set. Otherwise it is the
        square root of (transformed pixel count / original pixel count).
        """
        img = self.get_image(idx)
        img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        img, target = self.apply_transforms(img, target)
        if self.transforms is None:
            return img, target, idx, 1.0
        # Assumes the transformed img is a (C, H, W) tensor - TODO confirm.
        new_img_size = img.shape[1:]
        scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
        return img, target, idx, scale

    def get_line_no(self, idx):
        """Map a dataset index to a TSV row, honoring ``line_list`` when set."""
        return idx if self.line_list is None else self.line_list[idx]

    def _load_label_row(self, line_no):
        """Return the JSON-decoded second column of label row ``line_no``."""
        row = self.label_tsv.seek(line_no)
        return json.loads(row[1])

    def get_image(self, idx):
        """Decode and return the image for sample ``idx``."""
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            imageid = self._load_label_row(line_no)["img_id"]
            line_no = self.imageid2idx[imageid]
        row = self.img_tsv.seek(line_no)
        # use -1 to support old format with multiple columns.
        return img_from_base64(row[-1])

    def get_annotations(self, idx):
        """Return the decoded label JSON for ``idx``, or ``[]`` without labels."""
        if self.label_tsv is None:
            return []
        return self._load_label_row(self.get_line_no(idx))

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def apply_transforms(self, image, target=None):
        # This function will be overwritten by each dataset to
        # apply transforms to image and targets.
        return image, target

    def get_img_info(self, idx):
        """Return ``{"height": ..., "width": ...}`` for sample ``idx``, or ``None``.

        Sources, in priority order:
          * label rows, when ``imageid2idx`` is set (``img_h``/``img_w`` fields);
          * the hw TSV, whose second column is either JSON (a dict, or a list
            whose first element is returned) or plain ``"<height> <width>"``.
        Returns ``None`` if neither source is configured, or if the hw JSON is
        neither a list nor a dict.
        """
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            annotations = self._load_label_row(self.get_line_no(idx))
            # Height must come from "img_h", not "img_w". Fall back to "img_w"
            # only if a row lacks "img_h".
            # NOTE(review): confirm the label schema always carries img_h.
            height = annotations.get("img_h", annotations["img_w"])
            return {"height": int(height), "width": int(annotations["img_w"])}
        if self.hw_tsv is not None:
            row = self.hw_tsv.seek(self.get_line_no(idx))
            try:
                # json string format with "height" and "width" being the keys
                data = json.loads(row[1])
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(" ")
                return {"height": int(hw_str[0]), "width": int(hw_str[1])}
            if isinstance(data, list):
                return data[0]
            if isinstance(data, dict):
                return data
        return None

    def get_img_key(self, idx):
        """Return the key (first column, or ``img_id``) of sample ``idx``.

        The smaller hw / label rows are preferred over the image row,
        presumably to avoid reading a whole base64 image just for its key.
        """
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            return self._load_label_row(line_no)["img_id"]
        # Explicit None checks: TSVFile defines __len__, so truthiness would
        # load the whole lineidx and treat an empty TSV as missing.
        if self.hw_tsv is not None:
            return self.hw_tsv.seek(line_no)[0]
        if self.label_tsv is not None:
            return self.label_tsv.seek(line_no)[0]
        return self.img_tsv.seek(line_no)[0]
class TSVYamlDataset(TSVDataset):
    """TSVDataset taking a Yaml file for easy function call"""

    def __init__(self, yaml_file, root=None, replace_clean_label=False):
        """Build the dataset from a YAML config that names its TSV files.

        Args:
            yaml_file: YAML config with a required ``img`` entry and optional
                ``label``, ``hw``, ``linelist`` and ``imageid2idx`` entries.
            root: Directory used to resolve the entries. Defaults to the
                directory containing ``yaml_file`` when falsy.
            replace_clean_label: If True, replace ``raw_label`` with
                ``clean_label`` in the resolved label path. The resolved path
                must contain ``raw_label``.
        """
        print("Reading {}".format(yaml_file))
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = root if root else op.dirname(yaml_file)
        cfg = self.cfg
        img_file = find_file_path_in_yaml(cfg["img"], self.root)
        # Resolved in this exact order so the "found file" log order is unchanged.
        label_file, hw_file, linelist_file, imageid2idx_file = (
            find_file_path_in_yaml(cfg.get(key, None), self.root)
            for key in ("label", "hw", "linelist", "imageid2idx")
        )
        if replace_clean_label:
            assert "raw_label" in label_file
            label_file = label_file.replace("raw_label", "clean_label")
        super().__init__(img_file, label_file, hw_file, linelist_file, imageid2idx_file)
class ODTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.
    """

    def __init__(self, yaml_file, extra_fields=(), transforms=None, is_load_label=True, **kwargs):
        """Create the dataset and, optionally, its label map and label loader.

        Args:
            yaml_file: Dataset YAML config. If ``None``, the constructor
                returns immediately and leaves the instance uninitialized.
            extra_fields: Passed through to ``LabelLoader``.
            transforms: Callable ``(img, target) -> (img, target)`` used by
                ``apply_transforms``.
            is_load_label: If True, build the label map and a ``LabelLoader``.
            **kwargs: Accepted for compatibility; not used here.

        Raises:
            ValueError: if labels are requested but the config has neither a
                ``labelmap`` nor a ``jsondict`` entry, or the label-map file
                type is not recognized.
        """
        if yaml_file is None:
            return
        super(ODTSVDataset, self).__init__(yaml_file)
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.attribute_on = False
        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False
        if self.is_load_label:
            # construct maps
            jsondict = self._load_jsondict()
            self.labelmap = {}
            self.class_to_ind = jsondict["label_to_idx"]
            self.class_to_ind["__background__"] = 0
            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
            self.labelmap["class_to_ind"] = self.class_to_ind
            if self.attribute_on:
                self.attribute_to_ind = jsondict["attribute_to_idx"]
                self.attribute_to_ind["__no_attribute__"] = 0
                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
                self.labelmap["attribute_to_ind"] = self.attribute_to_ind
            self.label_loader = LabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
            )

    def _load_jsondict(self):
        """Load the label map named by the ``labelmap`` (else ``jsondict``) config entry.

        Returns a dict with a ``"label_to_idx"`` key. A JSON file is used
        as-is if it already has that key; otherwise it is wrapped as
        ``{"label_to_idx": <file contents>}``. A TSV file lists one label per
        line, numbered from 1 (index 0 is later assigned to ``__background__``).
        """
        jsondict_file = find_file_path_in_yaml(self.cfg.get("labelmap", None), self.root)
        if jsondict_file is None:
            jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        if jsondict_file is None:
            raise ValueError("dataset config has no 'labelmap' or 'jsondict' entry")
        # NOTE(review): the type is detected by substring match on the whole
        # path (kept for compatibility), so a directory name containing "json"
        # also selects the JSON branch.
        if "json" in jsondict_file:
            with open(jsondict_file, "r") as fp:
                jsondict = json.load(fp)
            if "label_to_idx" not in jsondict:
                jsondict = {"label_to_idx": jsondict}
            return jsondict
        if "tsv" in jsondict_file:
            with open(jsondict_file) as f:
                label_to_idx = {line.strip(): counter for counter, line in enumerate(f, start=1)}
            return {"label_to_idx": label_to_idx}
        raise ValueError("unsupported label map file (expected json or tsv): {}".format(jsondict_file))

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Convert raw annotations into a target using ``self.label_loader``.

        A bare list is treated as ``{"objects": annotations}``. Returns
        ``None`` when ``is_load_label`` is False.
        """
        if isinstance(annotations, list):
            annotations = {"objects": annotations}
        if self.is_load_label:
            return self.label_loader(annotations["objects"], img_size)
        return None

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` to ``(img, target)`` if set; else return them unchanged."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target