3v324v23's picture
add
c310e19
raw
history blame
11.9 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Simple dataset class that wraps a list of path names
"""
import os
import numpy as np
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import (
CharPolygons,
SegmentationCharMask,
SegmentationMask,
)
from PIL import Image, ImageDraw
class Tdtr(object):
    """Dataset that pairs every image in a directory with a text ground truth.

    Each ground-truth file is a plain-text file with one instance per line:
    comma-separated numeric coordinates (read as alternating x, y values)
    followed by one trailing field.  A trailing field equal to ``"1"`` marks
    the instance as difficult; it receives the label ``-1`` instead of ``1``.

    Args:
        use_charann: Whether character-level annotation may be used.
        imgs_dir: Directory whose entries are all treated as images.
        gts_dir: Directory with the ground-truth files, or ``None`` to
            return ``target=None`` from ``__getitem__``.
        transforms: Optional callable ``(img, target) -> (img, target)``.
        ignore_difficult: See ``load_gt_from_txt`` for its effect on
            loading.  When it is true and ``gts_dir`` contains the substring
            ``"train"``, images without any non-difficult instance are
            dropped at construction time.
    """

    def __init__(self, use_charann, imgs_dir, gts_dir, transforms=None, ignore_difficult=False):
        self.use_charann = use_charann
        self.image_lists = [os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)]
        self.gts_dir = gts_dir
        self.transforms = transforms
        self.min_proposal_size = 2
        self.char_classes = "_0123456789abcdefghijklmnopqrstuvwxyz"
        # Debug switch: when set, __getitem__ writes an overlay image to ./vis.
        self.vis = False
        self.ignore_difficult = ignore_difficult
        if self.ignore_difficult and (self.gts_dir is not None) and 'train' in self.gts_dir:
            self.image_lists = self.filter_image_lists()

    def _gt_path_for(self, im_name):
        """Return the ground-truth path for the image file name *im_name*.

        ``<gts_dir>/<im_name>.txt`` is preferred.  If it does not exist but
        ``<gts_dir>/gt_<stem>.txt`` does, the latter is returned.  When
        neither exists the preferred path is returned so that the caller's
        ``open`` raises ``FileNotFoundError`` for it.
        """
        primary = os.path.join(self.gts_dir, im_name + ".txt")
        if os.path.isfile(primary):
            return primary
        fallback = os.path.join(
            self.gts_dir, "gt_" + im_name.split(".")[0] + ".txt"
        )
        return fallback if os.path.isfile(fallback) else primary

    @staticmethod
    def _read_gt_lines(gt_path):
        """Return the non-blank lines of *gt_path*, closing the file afterwards."""
        with open(gt_path, 'r') as gt_file:
            # Blank lines (e.g. a trailing newline) carry no annotation and
            # would otherwise be parsed as an instance without coordinates.
            return [line for line in gt_file if line.strip()]

    def filter_image_lists(self):
        """Return the images that have at least one non-difficult instance.

        Raises:
            FileNotFoundError: If an image has no ground-truth file.
        """
        kept = []
        for img_path in self.image_lists:
            gt_path = self._gt_path_for(os.path.basename(img_path))
            has_positive = any(
                self.line2boxes(line)[0][0] != "1"
                for line in self._read_gt_lines(gt_path)
            )
            if has_positive:
                kept.append(img_path)
        return kept

    def __getitem__(self, item):
        """Load image *item* and, if ground truth is configured, its target.

        Returns:
            A tuple ``(img, target, img_path)``.  ``target`` is ``None`` when
            ``gts_dir`` is ``None``; ``img`` and ``target`` are whatever
            ``transforms`` returns when transforms are set.
        """
        img_path = self.image_lists[item]
        im_name = os.path.basename(img_path)
        # convert() returns an independent image, so the file can be closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.gts_dir is not None:
            target = self._build_target(im_name, img.size)
        else:
            target = None
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        # Without a target there is nothing to overlay.
        if self.vis and target is not None:
            self._save_visualization(img, target, im_name)
        return img, target, img_path

    def _build_target(self, im_name, img_size):
        """Build the BoxList target for *im_name*; *img_size* is (width, height)."""
        width, height = img_size
        gt_path = self._gt_path_for(im_name)
        words, boxes, charsbbs, segmentations, labels = self.load_gt_from_txt(
            gt_path, height, width
        )
        # Character annotation is used only when it is enabled and the first
        # word is non-empty.  ``words`` can be an empty list (see
        # load_gt_from_txt), which also means "no character annotation".
        use_char_ann = bool(self.use_charann) and bool(words) and words[0] != ""
        target = BoxList(
            boxes[:, :4], img_size, mode="xyxy", use_char_ann=use_char_ann
        )
        if self.ignore_difficult:
            labels = torch.from_numpy(np.array(labels))
        else:
            labels = torch.ones(len(boxes))
        target.add_field("labels", labels)
        masks = SegmentationMask(segmentations, img_size)
        target.add_field("masks", masks)
        char_masks = SegmentationCharMask(
            charsbbs,
            words=words,
            use_char_ann=use_char_ann,
            size=img_size,
            char_num_classes=len(self.char_classes),
        )
        target.add_field("char_masks", char_masks)
        return target

    def _save_visualization(self, img, target, im_name):
        """Write a debug overlay of the first mask and all boxes to ./vis.

        Assumes *img* is a channel-first tensor (it is converted with
        ``img.numpy()`` and transposed to channel-last), i.e. that the
        configured transforms have already been applied.
        """
        # NOTE(review): the added constants look like the per-channel pixel
        # mean removed by the normalisation transform -- confirm.
        new_im = img.numpy().copy().transpose([1, 2, 0]) + [
            102.9801,
            115.9465,
            122.7717,
        ]
        new_im = Image.fromarray(new_im.astype(np.uint8)).convert("RGB")
        mask = target.extra_fields["masks"].polygons[0].convert("mask")
        mask = Image.fromarray((mask.numpy() * 255).astype(np.uint8)).convert("RGB")
        if self.use_charann:
            m, _ = (
                target.extra_fields["char_masks"]
                .chars_boxes[0]
                .convert("char_mask")
            )
            color = self.creat_color_map(len(self.char_classes), 255)
            color_map = color[m.numpy().astype(np.uint8)]
            char = Image.fromarray(color_map.astype(np.uint8)).convert("RGB")
            char = Image.blend(char, new_im, 0.5)
        else:
            char = new_im
        new = Image.blend(char, mask, 0.5)
        img_draw = ImageDraw.Draw(new)
        for box in target.bbox.numpy():
            box = list(box)
            # Closed outline: (x0, y0) -> (x1, y0) -> (x1, y1) -> (x0, y1) -> (x0, y0).
            box = box[:2] + [box[2], box[1]] + box[2:] + [box[0], box[3]] + box[:2]
            img_draw.line(box, fill=(255, 0, 0), width=2)
        os.makedirs("./vis", exist_ok=True)
        new.save("./vis/char_" + im_name)

    def creat_color_map(self, n_class, width):
        """Return an ``(N, 3)`` integer array of RGB colours with ``N >= n_class``.

        Each channel is sampled on an evenly spaced grid over ``[0, width]``.
        The blue channel uses one sample fewer than red and green, so the
        table holds ``splits * splits * (splits - 1)`` colours.
        """
        splits = int(np.ceil(np.power((n_class * 1.0), 1.0 / 3)))
        # The cube-root estimate can be too small because of the shorter blue
        # axis (e.g. n_class=8 -> 4 colours) and is 1 for n_class=1, which
        # would divide by zero below.  Grow it until the table is big enough.
        while splits * splits * (splits - 1) < n_class:
            splits += 1
        maps = []
        for i in range(splits):
            r = int(i * width * 1.0 / (splits - 1))
            for j in range(splits):
                g = int(j * width * 1.0 / (splits - 1))
                for k in range(splits - 1):
                    b = int(k * width * 1.0 / (splits - 1))
                    maps.append([r, g, b])
        return np.array(maps)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_lists)

    def load_gt_from_txt(self, gt_path, height=None, width=None):
        """Parse one ground-truth file.

        Args:
            gt_path: Path of the ground-truth text file.
            height: Unused.
            width: Unused.

        Returns:
            A tuple ``(words, boxes, charsboxes, segmentations, labels)``.
            ``boxes`` is an ``(N, 5)`` array holding ``min_x, min_y, max_x,
            max_y`` (each reduced by 1) and the row index.  ``segmentations``
            holds one axis-aligned rectangle polygon per box and ``labels``
            holds ``-1`` for difficult instances and ``1`` otherwise.
            When no box was loaded, a single all-zero placeholder instance
            with the word ``""`` and label ``1`` is returned instead.
        """
        words, boxes, charsboxes, segmentations, labels = [], [], [], [], []
        for line in self._read_gt_lines(gt_path):
            strs, loc = self.line2boxes(line)
            word = strs[0]
            if not self.ignore_difficult:
                # NOTE(review): with ignore_difficult=False every line is
                # skipped, so only the placeholder below is ever returned.
                # Kept as in the original; confirm this is intended.
                continue
            rect = list(loc[0])
            min_x = min(rect[::2]) - 1
            min_y = min(rect[1::2]) - 1
            max_x = max(rect[::2]) - 1
            max_y = max(rect[1::2]) - 1
            # The mask polygon is the bounding rectangle, not the raw points.
            segmentations.append(
                [[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]]
            )
            tindex = len(boxes)
            boxes.append([min_x, min_y, max_x, max_y])
            # NOTE(review): the trailing field is not appended to ``words``,
            # so ``words`` stays empty whenever boxes were loaded.
            labels.append(-1 if word == '1' else 1)
            if loc.shape[0] > 1:
                charbb = np.zeros((10,), dtype=np.float32)
                charbb[9] = tindex
                charbbs = [charbb.copy() for _ in range(1, loc.shape[0])]
                charsboxes.append(charbbs)

        num_boxes = len(boxes)
        if num_boxes == 0:
            words.append("")
            charbbs = np.zeros((10,), dtype=np.float32)
            return (
                words,
                np.zeros((1, 5), dtype=np.float32),
                [[charbbs]],
                [[np.zeros((8,), dtype=np.float32)]],
                [1],
            )

        keep_boxes = np.zeros((num_boxes, 5))
        keep_boxes[:, :4] = np.array(boxes)
        # The 5th column is the box index; it matches the 10th column of the
        # character boxes that belong to the box.
        keep_boxes[:, 4] = range(num_boxes)
        if not self.use_charann and len(charsboxes) == 0:
            charbbs = np.zeros((10,), dtype=np.float32)
            # NOTE(review): iterates over ``words``, which is empty here, so
            # no placeholder is added.  Kept as in the original.
            for _ in range(len(words)):
                charsboxes.append([charbbs])
        return words, np.array(keep_boxes), charsboxes, segmentations, labels

    def line2boxes(self, line):
        """Split a ground-truth line into its trailing field and coordinates.

        Returns:
            ``([last_field], coords)`` where ``coords`` is a ``(1, K)`` float
            array built from every comma-separated field but the last.

        Raises:
            ValueError: If a coordinate field is not a valid float.
        """
        parts = line.strip().split(",")
        coords = [float(x) for x in parts[:-1]]
        return [parts[-1]], np.array([coords])

    def check_charbbs(self, charbbs):
        """Return a boolean array: which rows of *charbbs* are large enough.

        Columns 0-7 of each row are read as four (x, y) corner points.  A row
        passes when both its x extent and its y extent exceed
        ``min_proposal_size``.
        """
        xs = [charbbs[:, col] for col in (0, 2, 4, 6)]
        ys = [charbbs[:, col] for col in (1, 3, 5, 7)]
        x_extent = np.maximum.reduce(xs) - np.minimum.reduce(xs)
        y_extent = np.maximum.reduce(ys) - np.minimum.reduce(ys)
        return np.logical_and(
            x_extent > self.min_proposal_size,
            y_extent > self.min_proposal_size,
        )

    def check_charbb(self, charbb):
        """Single-box variant of ``check_charbbs`` for one 8-value quadrilateral."""
        xs = (charbb[0], charbb[2], charbb[4], charbb[6])
        ys = (charbb[1], charbb[3], charbb[5], charbb[7])
        return (
            max(xs) - min(xs) > self.min_proposal_size
            and max(ys) - min(ys) > self.min_proposal_size
        )

    def char2num(self, chars):
        """Map characters to their index in ``char_classes`` (case-insensitive).

        Example: ``['h', 'e', 'l', 'l', 'o']`` -> ``[18, 15, 22, 22, 25]``.

        Raises:
            ValueError: If a character is not in ``char_classes``.
        """
        return [self.char_classes.index(c.lower()) for c in chars]

    def get_img_info(self, item):
        """
        Return the image dimensions for the image, without
        loading and pre-processing it
        """
        img_path = self.image_lists[item]
        # Only the header is needed for the size; close the file right away.
        with Image.open(img_path) as img:
            width, height = img.size
        return {
            "im_name": os.path.basename(img_path),
            "height": height,
            "width": width,
        }