import os
import os.path as op
import json
# import logging
import base64
import yaml
import errno
import io
import math

from PIL import Image, ImageDraw

from maskrcnn_benchmark.structures.bounding_box import BoxList

from .box_label_loader import LabelLoader


def load_linelist_file(linelist_file):
    if linelist_file is not None:
        line_list = []
        with open(linelist_file, "r") as fp:
            for i in fp:
                line_list.append(int(i.strip()))
        return line_list


def img_from_base64(imagestring):
    try:
        img = Image.open(io.BytesIO(base64.b64decode(imagestring)))
        return img.convert("RGB")
    except ValueError:
        return None


def load_from_yaml_file(yaml_file):
    with open(yaml_file, "r") as fp:
        return yaml.load(fp, Loader=yaml.CLoader)


def find_file_path_in_yaml(fname, root):
    if fname is not None:
        found_file = None
        if op.isfile(fname):
            found_file = fname
        elif op.isfile(op.join(root, fname)):
            found_file = op.join(root, fname)
        else:
            # be a bit more flexible and also look in a few parent directories of root
            try_time = 3
            while try_time > 0:
                try_time -= 1
                root = os.path.dirname(root)
                if op.isfile(op.join(root, fname)):
                    found_file = op.join(root, fname)
                    break
        if found_file is None:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
            )
        print('found file: {}'.format(found_file))
        return found_file


def create_lineidx(filein, idxout):
    idxout_tmp = idxout + ".tmp"
    with open(filein, "r") as tsvin, open(idxout_tmp, "w") as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)


def read_to_character(fp, c):
    result = []
    while True:
        s = fp.read(32)
        assert s != ""
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return "".join(result)


class TSVFile(object):
    def __init__(self, tsv_file, generate_lineidx=False):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + ".lineidx"
        self._fp = None
        self._lineidx = None
        # self.pid stores the id of the process that opened the file.
        # If it differs from the current pid (e.g. after a fork), the file is re-opened.
        self.pid = None

        # generate the lineidx file if it does not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            create_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except:
            # logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split("\t")]

    def seek_first_column(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, "\t")

    def get_key(self, idx):
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            # logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, "r") as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()

        if self.pid != os.getpid():
            # logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()


class CompositeTSVFile:
    def __init__(self, file_list, seq_file, root="."):
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        idx_source, idx_row = self.seq[index]
        k = self.tsvs[idx_source].get_key(idx_row)
        return "_".join([self.file_list[idx_source], k])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        idx_source, idx_row = self.seq[index]
        return self.tsvs[idx_source].seek(idx_row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        """
        This function has to be called in __init__ if cache_policy is enabled.
        Thus, to keep things simple, we always call it in __init__.
        """
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, "r") as fp:
            for line in fp:
                parts = line.strip().split("\t")
                self.seq.append([int(parts[0]), int(parts[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True


def load_list_file(fname):
    with open(fname, "r") as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == "":
        result = result[:-1]
    return result


class TSVDataset(object):
    def __init__(self, img_file, label_file=None, hw_file=None,
                 linelist_file=None, imageid2idx_file=None):
        """Constructor.

        Args:
            img_file: Image file with image key and base64 encoded image str.
            label_file: An optional label file with image key and label information.
                A label_file is required for training and optional for testing.
            hw_file: An optional file with image key and image height/width info.
            linelist_file: An optional file with a list of line indexes to load samples.
                It is useful to select a subset of samples or to duplicate samples.
            imageid2idx_file: An optional json file mapping an image key (the "img_id"
                field in the label annotations) to its row index in img_file.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = TSVFile(img_file)
        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
        self.line_list = load_linelist_file(linelist_file)
        self.imageid2idx = None
        if imageid2idx_file is not None:
            self.imageid2idx = json.load(open(imageid2idx_file, "r"))
        self.transforms = None

    def __len__(self):
        if self.line_list is None:
            if self.imageid2idx is not None:
                assert self.label_tsv is not None, "label_tsv is None!!!"
                return self.label_tsv.num_rows()
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        img = self.get_image(idx)
        img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        # print(idx, annotations)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        img, target = self.apply_transforms(img, target)

        if self.transforms is None:
            return img, target, idx, 1.0
        else:
            new_img_size = img.shape[1:]
            scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
            return img, target, idx, scale

    def get_line_no(self, idx):
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            imageid = annotations["img_id"]
            line_no = self.imageid2idx[imageid]

        row = self.img_tsv.seek(line_no)
        # use -1 to support the old format with multiple columns.
        img = img_from_base64(row[-1])
        return img

    def get_annotations(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def apply_transforms(self, image, target=None):
        # This function will be overwritten by each dataset to
        # apply transforms to image and targets.
        return image, target

    def get_img_info(self, idx):
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            line_no = self.get_line_no(idx)
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return {"height": int(annotations["img_h"]), "width": int(annotations["img_w"])}

        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv.seek(line_no)
            try:
                # json string format with "height" and "width" being the keys
                data = json.loads(row[1])
                if type(data) == list:
                    return data[0]
                elif type(data) == dict:
                    return data
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(" ")
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        line_no = self.get_line_no(idx)
        # the order of the checks below is based on the overhead of reading a row
        # from each file (image rows carry base64 data and are the most expensive).
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            # with an imageid2idx map, line_no indexes the label file and the
            # image key is the "img_id" stored in that label row.
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations["img_id"]

        if self.hw_tsv:
            return self.hw_tsv.seek(line_no)[0]
        elif self.label_tsv:
            return self.label_tsv.seek(line_no)[0]
        else:
            return self.img_tsv.seek(line_no)[0]


class TSVYamlDataset(TSVDataset):
    """TSVDataset whose individual TSV files are listed in a yaml config."""

    def __init__(self, yaml_file, root=None, replace_clean_label=False):
        print("Reading {}".format(yaml_file))
        self.cfg = load_from_yaml_file(yaml_file)
        if root:
            self.root = root
        else:
            self.root = op.dirname(yaml_file)
        img_file = find_file_path_in_yaml(self.cfg["img"], self.root)
        label_file = find_file_path_in_yaml(self.cfg.get("label", None), self.root)
        hw_file = find_file_path_in_yaml(self.cfg.get("hw", None), self.root)
        linelist_file = find_file_path_in_yaml(self.cfg.get("linelist", None), self.root)
        imageid2idx_file = find_file_path_in_yaml(self.cfg.get("imageid2idx", None), self.root)

        if replace_clean_label:
            assert "raw_label" in label_file
            label_file = label_file.replace("raw_label", "clean_label")

        super(TSVYamlDataset, self).__init__(img_file, label_file, hw_file,
                                             linelist_file, imageid2idx_file)


class ODTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.
    """

    def __init__(self, yaml_file, extra_fields=(), transforms=None, is_load_label=True, **kwargs):
        if yaml_file is None:
            return
        super(ODTSVDataset, self).__init__(yaml_file)

        self.transforms = transforms
        self.is_load_label = is_load_label
        self.attribute_on = False
        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False

        if self.is_load_label:
            # construct the label maps
            jsondict_file = find_file_path_in_yaml(self.cfg.get("labelmap", None), self.root)
            if jsondict_file is None:
                jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
            if "json" in jsondict_file:
                jsondict = json.load(open(jsondict_file, "r"))
                if "label_to_idx" not in jsondict:
                    jsondict = {"label_to_idx": jsondict}
            elif "tsv" in jsondict_file:
                label_to_idx = {}
                counter = 1
                with open(jsondict_file) as f:
                    for line in f:
                        label_to_idx[line.strip()] = counter
                        counter += 1
                jsondict = {"label_to_idx": label_to_idx}
            else:
                assert 0

            self.labelmap = {}

            self.class_to_ind = jsondict["label_to_idx"]
            self.class_to_ind["__background__"] = 0
            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
            self.labelmap["class_to_ind"] = self.class_to_ind

            if self.attribute_on:
                self.attribute_to_ind = jsondict["attribute_to_idx"]
                self.attribute_to_ind["__no_attribute__"] = 0
                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
                self.labelmap["attribute_to_ind"] = self.attribute_to_ind

            self.label_loader = LabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
            )

    def get_target_from_annotations(self, annotations, img_size, idx):
        if isinstance(annotations, list):
            annotations = {"objects": annotations}
        if self.is_load_label:
            return self.label_loader(annotations["objects"], img_size)

    def apply_transforms(self, img, target=None):
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
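

# The snippet below is a minimal usage sketch, not part of the library API:
# the "data/train.img.tsv" and "data/train.yaml" paths are placeholders and
# assume TSV/yaml files in the format read by the classes above exist on disk.
if __name__ == "__main__":
    # Row-level access: TSVFile builds a .lineidx offset file on first use
    # (generate_lineidx=True) and then supports random access by row index.
    tsv = TSVFile("data/train.img.tsv", generate_lineidx=True)
    print("num rows:", tsv.num_rows())
    row = tsv[0]                      # list of columns; first is the image key
    img = img_from_base64(row[-1])    # decode the last column, as get_image() does

    # Dataset-level access: the yaml config names the img/label/hw/linelist TSVs,
    # which are resolved relative to the yaml file via find_file_path_in_yaml.
    dataset = TSVYamlDataset("data/train.yaml")
    img, target, idx, scale = dataset[0]
    print("image size:", img.size, "scale:", scale)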