# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Simple dataset class that wraps a list of path names
"""
import os

import numpy as np
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import (
    CharPolygons,
    SegmentationCharMask,
    SegmentationMask,
)
from PIL import Image, ImageDraw


class TotaltextDataset(object):
    """Image dataset whose ground truth is one comma-separated text file per image.

    Every file found in ``imgs_dir`` is treated as an image. Its annotation is
    looked up in ``gts_dir`` as ``<image file name>.txt`` and, if that file
    does not exist, as ``gt_<image stem>.txt``. Each annotation line holds
    polygon coordinates followed by the transcription; the transcription
    ``"###"`` marks a difficult instance.
    """

    def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False):
        """Collect the image paths and store the loading options.

        Args:
            use_charann: whether character-level annotations should be used.
            imgs_dir: directory whose entries are all treated as images.
            gts_dir: directory with the annotation files, or ``None`` to
                produce no targets.
            transforms: optional callable applied as
                ``img, target = transforms(img, target)``.
            ignore_difficult: when true, "###" instances are kept as boxes
                labelled ``-1`` instead of being dropped; additionally, if
                ``gts_dir`` contains the substring ``'train'``, images without
                any non-"###" instance are removed from the dataset.
        """
        self.use_charann = use_charann
        self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)]
        self.gts_dir = gts_dir
        self.transforms = transforms
        self.min_proposal_size = 2
        self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz"
        self.vis = False
        self.ignore_difficult = ignore_difficult
        if self.ignore_difficult and (self.gts_dir is not None) and 'train' in self.gts_dir:
            self.image_lists = self.filter_image_lists()

    def _resolve_gt_path(self, im_name):
        """Return the annotation path for ``im_name``, trying both naming schemes."""
        gt_path = os.path.join(self.gts_dir, im_name + ".txt")
        if not os.path.isfile(gt_path):
            gt_path = os.path.join(
                self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt"
            )
        return gt_path

    @staticmethod
    def _read_gt_lines(gt_path):
        """Return the non-blank lines of an annotation file, closing the file."""
        with open(gt_path, 'r') as gt_file:
            # Blank lines carry no annotation and cannot be parsed into a box.
            return [line for line in gt_file if line.strip()]

    @staticmethod
    def _bounding_rect(rect):
        """Return ``(min_x, min_y, max_x, max_y)`` of a flat x,y list, each minus 1."""
        xs, ys = rect[::2], rect[1::2]
        return min(xs) - 1, min(ys) - 1, max(xs) - 1, max(ys) - 1

    def filter_image_lists(self):
        """Return the image paths that have at least one non-"###" annotation."""
        kept = []
        for img_path in self.image_lists:
            gt_path = self._resolve_gt_path(os.path.basename(img_path))
            words = (
                self.line2boxes(line)[0][0] for line in self._read_gt_lines(gt_path)
            )
            if any(word != "###" for word in words):
                kept.append(img_path)
        return kept

    def __getitem__(self, item):
        """Load one sample.

        Returns:
            ``(img, target, path)`` where ``target`` is a ``BoxList`` carrying
            the fields ``"labels"``, ``"masks"`` and ``"char_masks"``, or
            ``None`` when ``gts_dir`` is ``None``. ``img`` and ``target`` are
            the outputs of ``transforms`` when one was given.
        """
        img_path = self.image_lists[item]
        im_name = os.path.basename(img_path)
        # convert() returns a new image, so the file handle can be released here.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        width, height = img.size

        if self.gts_dir is not None:
            # Same lookup as filter_image_lists, so an image that survived
            # filtering can always be loaded.
            gt_path = self._resolve_gt_path(im_name)
            words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt(
                gt_path, height, width
            )
            # load_gt_from_txt reports "no instance" as a single empty word.
            if words[0] == "":
                use_char_ann = False
            else:
                use_char_ann = True
            if not self.use_charann:
                use_char_ann = False
            target = BoxList(
                boxes[:, :4], img.size, mode="xyxy", use_char_ann=use_char_ann
            )
            if self.ignore_difficult:
                labels = torch.from_numpy(np.array(labels))
            else:
                labels = torch.ones(len(boxes))
            target.add_field("labels", labels)
            masks = SegmentationMask(segmentations, img.size)
            target.add_field("masks", masks)
            char_masks = SegmentationCharMask(
                charsbbs,
                words=words,
                use_char_ann=use_char_ann,
                size=img.size,
                char_num_classes=len(self.char_classes),
            )
            target.add_field("char_masks", char_masks)
        else:
            target = None

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        if self.vis:
            self._save_visualization(img, target, im_name)
        return img, target, img_path

    def _save_visualization(self, img, target, im_name):
        """Write a debug overlay of the first mask and all boxes to ``./vis``.

        Only the first polygon mask (and, with ``use_charann``, the first
        character mask) is blended in; every box in ``target.bbox`` is drawn
        in red.
        """
        # NOTE(review): the added constants are presumably the per-channel
        # means subtracted by the transforms — confirm against the pipeline.
        new_im = img.numpy().copy().transpose([1, 2, 0]) + [
            102.9801,
            115.9465,
            122.7717,
        ]
        new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB")
        mask = target.extra_fields["masks"].polygons[0].convert("mask")
        mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB")
        if self.use_charann:
            m, _ = (
                target.extra_fields["char_masks"]
                .chars_boxes[0]
                .convert("char_mask")
            )
            color = self.creat_color_map(37, 255)
            color_map = color[m.numpy().astype(np.uint8)]
            char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB")
            char = Image.blend(char, new_im, 0.5)
        else:
            char = new_im
        new = Image.blend(char, mask, 0.5)
        img_draw = ImageDraw.Draw(new)
        for box in target.bbox.numpy():
            box = list(box)
            # Closed outline: top-left, top-right, bottom-right, bottom-left, top-left.
            box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2]
            img_draw.line(box, fill=(255, 0, 0), width=2)
        new.save("./vis/char_" + im_name)

    def creat_color_map(self, n_class, width):
        """Return an array of ``[r, g, b]`` rows used to colour class indices.

        ``splits`` is the ceiling of the cube root of ``n_class``; red and
        green take ``splits`` evenly spaced levels in ``[0, width]`` while
        blue takes the first ``splits - 1`` of those levels.
        """
        splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3)))
        maps = []
        for i in range(splits):
            r = int(i * width * 1.0 / (splits - 1))
            for j in range(splits):
                g = int(j * width * 1.0 / (splits - 1))
                for k in range(splits - 1):
                    b = int(k * width * 1.0 / (splits - 1))
                    maps.append([r, g, b])
        return np.array(maps)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_lists)

    def load_gt_from_txt(self, gt_path, height=None, width=None):
        """Parse one annotation file.

        Args:
            gt_path: path of the annotation file.
            height: unused; kept for interface compatibility.
            width: unused; kept for interface compatibility.

        Returns:
            ``(words, boxes, charsboxes, segmentations, labels)``.
            ``boxes`` is an ``(N, 5)`` array of ``min_x, min_y, max_x, max_y``
            plus the box index. ``labels`` holds ``1`` for regular instances
            and ``-1`` for "###" instances, which are only kept when
            ``ignore_difficult`` is set. When no instance is kept, a single
            all-zero placeholder instance with the word ``""`` and label ``1``
            is returned instead.
        """
        words, boxes, charsboxes, segmentations, labels = [], [], [], [], []
        for line in self._read_gt_lines(gt_path):
            strs, loc = self.line2boxes(line)
            word = strs[0]
            difficult = word == "###"
            if difficult and not self.ignore_difficult:
                continue

            min_x, min_y, max_x, max_y = self._bounding_rect(list(loc[0]))
            tindex = len(boxes)
            boxes.append([min_x, min_y, max_x, max_y])
            words.append(word)

            charbb = np.zeros((10,), dtype=np.float32)
            if difficult:
                # Difficult instances use their bounding rectangle as polygon.
                segmentations.append(
                    [[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]]
                )
                labels.append(-1)
                if loc.shape[0] > 1:
                    # One placeholder per character row, tagged with the box
                    # index. (This path used to fail on an undefined name.)
                    charbb[9] = tindex
                    charbbs = [charbb.copy() for _ in range(1, loc.shape[0])]
                else:
                    charbbs = charbb
            else:
                segmentations.append([loc[0, :]])
                labels.append(1)
                c_class = self.char2num(strs[1:])
                charbbs = []
                for i in range(1, loc.shape[0]):
                    charbb[:8] = loc[i, :]
                    charbb[8] = c_class[i - 1]
                    charbb[9] = tindex
                    charbbs.append(charbb.copy())
            charsboxes.append(charbbs)

        num_boxes = len(boxes)
        if num_boxes > 0:
            keep_boxes = np.zeros((num_boxes, 5))
            keep_boxes[:, :4] = np.array(boxes)
            # The 5th column is the box index, the same value stored in the
            # 10th column of the character boxes belonging to that box.
            keep_boxes[:, 4] = range(num_boxes)
            return words, keep_boxes, charsboxes, segmentations, labels

        words.append("")
        charbbs = np.zeros((10,), dtype=np.float32)
        return (
            words,
            np.zeros((1, 5), dtype=np.float32),
            [[charbbs]],
            [[np.zeros((8,), dtype=np.float32)]],
            [1],
        )

    def line2boxes(self, line):
        """Split an annotation line into ``([word], coords)``.

        The last comma-separated field is the word; all preceding fields are
        parsed as floats into an array with a single row.
        """
        parts = line.strip().split(",")
        return [parts[-1]], np.array([[float(x) for x in parts[:-1]]])

    def check_charbbs(self, charbbs):
        """Return a boolean mask of rows whose extent exceeds ``min_proposal_size``.

        Columns 0, 2, 4, 6 are read as x coordinates and 1, 3, 5, 7 as y
        coordinates; both the x and the y extent must be larger than the
        threshold.
        """
        xmins = np.minimum.reduce(
            [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]]
        )
        xmaxs = np.maximum.reduce(
            [charbbs[:, 0], charbbs[:, 2], charbbs[:, 4], charbbs[:, 6]]
        )
        ymins = np.minimum.reduce(
            [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]]
        )
        ymaxs = np.maximum.reduce(
            [charbbs[:, 1], charbbs[:, 3], charbbs[:, 5], charbbs[:, 7]]
        )
        return np.logical_and(
            xmaxs - xmins > self.min_proposal_size,
            ymaxs - ymins > self.min_proposal_size,
        )

    def check_charbb(self, charbb):
        """Return whether one character box exceeds ``min_proposal_size`` in x and y."""
        xmins = min(charbb[0], charbb[2], charbb[4], charbb[6])
        xmaxs = max(charbb[0], charbb[2], charbb[4], charbb[6])
        ymins = min(charbb[1], charbb[3], charbb[5], charbb[7])
        ymaxs = max(charbb[1], charbb[3], charbb[5], charbb[7])
        return (
            xmaxs - xmins > self.min_proposal_size
            and ymaxs - ymins > self.min_proposal_size
        )

    def char2num(self, chars):
        """Map characters to their index in ``char_classes`` (case-insensitive).

        Example input: ``['h', 'e', 'l', 'l', 'o']``. Raises ``ValueError``
        for a character that is not in ``char_classes``.
        """
        nums = [self.char_classes.index(c.lower()) for c in chars]
        return nums

    def get_img_info(self, item):
        """
        Return the image dimensions for the image, without loading and pre-processing it
        """
        im_name = os.path.basename(self.image_lists[item])
        # Only the header is needed for the size; close the file right away.
        with Image.open(self.image_lists[item]) as img:
            width, height = img.size
        img_info = {"im_name": im_name, "height": height, "width": width}
        return img_info