REC-MV_preprocess / parsing_mask_to_fl.py
mambazjp's picture
Upload 11 files
ef6a8c6
raw
history blame
14.4 kB
"""
@File: parsing_mask_to_fl.py
@Author: Lingteng Qiu
@Email: qiulingteng@link.cuhk.edu.cn
@Date: 2022-10-19
@Desc: parsing mask to polygons, given a series of mask infos
"""
import sys

# NOTE: the original `sys.path.extend('./')` iterated the *string*, appending
# '.' and '/' as two separate path entries; append the path as one entry.
sys.path.append('./')

import argparse
import glob
import json
import os
import os.path as osp
import pdb

import cv2
import numpy as np
import torch
from pytorch3d.ops.knn import knn_points

from utils.constant import FL_EXTRACT, TEMPLATE_GARMENT
# How feature lines are extracted from a parsing mask:
# besides the parse mask we need a json file with two annotated points per
# feature line; if a hand occludes the waist (or hair occludes the neck),
# six points are annotated instead so the line can route around the occlusion.

# BGR colors used for the debug overlays, one per feature-line label.
FL_COLOR = {
    'neck': (0, 0, 255),
    'right_cuff': (0, 255, 0),
    'left_cuff': (255, 0, 0),
    'left_pant': (127, 127, 0),
    'right_pant': (0, 127, 127),
    'upper_bottom': (127, 0, 127),
    # was (0, 127, 127) -- identical to 'right_pant', which made the two
    # curves indistinguishable in debug images; use a distinct color.
    'bottom_curve': (127, 127, 127),
}
def draw_lines(pts, color, img=None):
    """Draw the open polyline through ``pts`` onto ``img`` and return it.

    A fresh black 1080x1080 BGR canvas is allocated when ``img`` is None.
    """
    canvas = np.zeros((1080, 1080, 3), np.uint8) if img is None else img
    for start, end in zip(pts[:-1], pts[1:]):
        cv2.line(canvas,
                 (int(start[0]), int(start[1])),
                 (int(end[0]), int(end[1])),
                 color, 2)
    return canvas
def draw_pts(pts, color, img=None):
    """Draw each point in ``pts`` as a small filled circle and return the image.

    A fresh black 1080x1080 BGR canvas is allocated when ``img`` is None.
    """
    canvas = img if img is not None else np.zeros((1080, 1080, 3), np.uint8)
    for pt in pts:
        cv2.circle(canvas, (int(pt[0]), int(pt[1])), 2, color, -1)
    return canvas
class PolyMask(object):
    """Wraps a [H, W] parsing-label map and snaps annotated feature lines
    onto the boundary of the garment region extracted from it."""

    def __init__(self, mask):
        # mask: 2-D integer label map (one parsing label per pixel).
        self.mask = mask

    def query(self, query_sets, labels, garment_key):
        """Refine annotated feature lines against the garment mask contour.

        How a feature line is obtained from the parsing mask:
        1. find where the annotated line meets the mask contour;
        2. if there are more than two meeting points, pick the pair whose
           connecting stretch lies on the feature line and inside the mask
           for the longest span;
        3. the feature line is that point pair plus the stretch of mask
           boundary between the two points.

        Each annotated polyline must have an even number of points: they are
        consumed as consecutive point pairs (extra pairs are annotated to
        route around occlusions, e.g. a hand over the waist or hair over the
        neck).

        Args:
            query_sets: dict feature-line label -> (N, 2) float point array.
            labels: parsing-label ids forming this garment region.
            garment_key: garment name (currently unused; kept for interface).

        Returns:
            (new_query_sets, mask): refined polylines per label and the
            uint8 {0, 255} garment mask that was used.
        """
        # Union of all requested parsing labels -> binary garment mask.
        # (`bool` instead of `np.bool`: the alias was removed in NumPy 1.24.)
        mask = np.zeros_like(self.mask, dtype=bool)
        for label in labels:
            label_mask = np.zeros_like(self.mask, dtype=bool)
            i, j = np.where(self.mask == label)
            label_mask[i, j] = True
            mask |= label_mask
        # [0, 255]
        mask = mask.astype(np.uint8) * 255
        mask = self.smooth_noise(mask)
        mask_polygons, mask_area = self.mask2polygon(mask)
        # Cumulative Manhattan arc length along each contour; dp[k] is the
        # distance from vertex 0 to vertex k, the last entry closes the loop.
        length_dp = []
        for mask_polygon in mask_polygons:
            dis = [0]
            dis.extend([abs(mask_polygon[p_i][0] - mask_polygon[p_i + 1][0]) + abs(mask_polygon[p_i][1] - mask_polygon[p_i + 1][1]) for p_i in range(mask_polygon.shape[0] - 1)])
            dis.append(abs(mask_polygon[0][0] - mask_polygon[-1][0]) + abs(mask_polygon[0][1] - mask_polygon[-1][1]))
            dp = np.cumsum(dis)
            length_dp.append(dp)
        new_query_sets = {}
        # All contour vertices stacked, used for nearest-neighbour lookups.
        reply_pts = np.concatenate(mask_polygons, axis=0)
        reply_pts = torch.from_numpy(reply_pts).float().cuda()
        for key in query_sets.keys():
            polygon = query_sets[key]
            # len(featureline points) % 2 == 0
            assert polygon.shape[0] % 2 == 0
            # Group the annotated points into consecutive point pairs.
            polygons = polygon.reshape(-1, 2, 2)
            # group[k] = index of the contour that reply_pts[k] came from.
            group = []
            for group_id, mask_polygon in enumerate(mask_polygons):
                group.extend([group_id for i in range(mask_polygon.shape[0])])
            group = torch.tensor(group).long()
            new_polygons = []
            pre_polygon = None
            # Walk over each annotated point pair.
            for polygon in polygons:
                polygon = torch.from_numpy(polygon).float().cuda()
                if pre_polygon is not None:
                    dis = torch.sqrt(((polygon[0] - pre_polygon[-1]) ** 2).sum())
                    # If this pair starts right where the previous one ended,
                    # keep it verbatim: it routes around an occlusion (hand
                    # or hair blocking the mask).
                    if dis < 10:
                        new_polygons.append(polygon.detach().cpu().numpy())
                        pre_polygon = None
                        continue
                pre_polygon = polygon.detach().clone()
                # Nearest contour vertex for each of the pair's two points.
                dist = knn_points(polygon[None], reply_pts[None])
                idx = dist.idx[0, ..., 0]
                group_id = group[idx]
                # Too far from any contour: keep the annotation as-is.
                if dist.dists.max() > 1000:
                    new_polygons.append(polygon.detach().cpu().numpy())
                    continue
                # Prefer the contour enclosing the larger mask area.
                prefer_id = group_id[0] if mask_area[group_id[0]] > mask_area[group_id[1]] else group_id[1]
                prefer_pts = torch.from_numpy(mask_polygons[prefer_id]).float().cuda()
                # Re-run the nearest lookup on the preferred contour only.
                dist = knn_points(polygon[None], prefer_pts[None])
                idx = dist.idx[0, ..., 0].sort()
                polygon = polygon[idx.indices]
                idx = idx.values
                reverse_flag = (not idx[0] == dist.idx[0, 0, 0])
                # Pick the shorter of the two boundary paths between the two
                # matched contour vertices (the contour is a closed loop).
                dp = length_dp[prefer_id]
                slice_a = dp[idx[1]] - dp[idx[0]]
                slice_b = dp[-1] - slice_a
                if slice_a > slice_b:
                    # Shorter path wraps around the contour's start vertex.
                    segment = torch.cat([polygon[1:], prefer_pts[idx[1]:], prefer_pts[:idx[0] + 1], polygon[0:1]], dim=0)
                    reverse_flag = (not reverse_flag)
                else:
                    segment = torch.cat([polygon[0:1], prefer_pts[idx[0]:idx[1] + 1], polygon[1:]], dim=0)
                segment = segment.detach().cpu().numpy()
                if reverse_flag:
                    segment = segment[::-1]
                new_polygons.append(segment)
            new_polygons = np.concatenate(new_polygons, axis=0)
            new_query_sets[key] = new_polygons
        return new_query_sets, mask

    def smooth_noise(self, mask):
        """Morphological open (erode then dilate) to remove speckle noise."""
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
        mask = cv2.erode(mask, kernel, iterations=2)
        mask = cv2.dilate(mask, kernel, iterations=2)
        return mask

    def mask2polygon(self, mask):
        """Extract mask contours as (N, 2) vertex arrays plus their areas.

        Returns:
            (segmentation, polygon_size): list of contour vertex arrays and
            the pixel area enclosed by each (contours with <= 2 vertices are
            dropped).
        """
        contours, hierarchy = cv2.findContours((mask).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        segmentation = []
        polygon_size = []
        for contour in contours:
            contour_list = contour.flatten().tolist()
            if len(contour_list) > 4:  # and cv2.contourArea(contour)>10000
                area = self.polygons_to_mask(mask.shape, contour_list).sum()
                polygon_size.append(area)
                contour_numpy = np.asarray(contour_list).reshape(-1, 2)
                segmentation.append(contour_numpy)
        return segmentation, polygon_size

    def polygons_to_mask(self, img_shape, polygons):
        """Rasterize a flat [x0, y0, x1, y1, ...] polygon into a 0/1 mask."""
        mask = np.zeros(img_shape, dtype=np.uint8)
        # cv2.fillPoly requires int32 vertices; other dtypes raise an error.
        polygons = np.asarray(polygons, np.int32)
        polygons = polygons.reshape(-1, 2)
        cv2.fillPoly(mask, [polygons], color=1)
        return mask
def get_upper_bttom_type(parsing_type, key):
    """Return the parsing-label ids covering a garment region.

    Args:
        parsing_type: 'ATR' selects the ATR label scheme; anything else
            selects the CLO scheme.
        key: one of 'upper', 'bottom', 'upper_bottom'.

    Returns:
        List of integer parsing-label ids for that region.
    """
    # ATR label ids: 0 Background, 1 Hat, 2 Hair, 3 Sunglasses,
    # 4 Upper-clothes, 5 Skirt, 6 Pants, 7 Dress, 8 Belt, 9 Left-shoe,
    # 10 Right-shoe, 11 Face, 12 Left-leg, 13 Right-leg, 14 Left-arm,
    # 15 Right-arm, 16 Bag, 17 Scarf.
    atr_labels = {
        'upper': [4, 16, 17],            # upper-clothes, bag, scarf
        'bottom': [5, 6, 8],             # skirt, pants, belt
        'upper_bottom': [4, 5, 6, 7, 8, 16, 17],
    }
    # The CLO scheme folds every garment region into the same three ids.
    clo_labels = {
        'upper': [1, 2, 3],
        'bottom': [1, 2, 3],
        'upper_bottom': [1, 2, 3],
    }
    table = atr_labels if parsing_type == 'ATR' else clo_labels
    return table[key]
def get_parsing_label(parsing_type):
    """Return the ordered class-name list of a parsing scheme.

    Args:
        parsing_type: 'ATR' or 'CLO'.

    Raises:
        KeyError: for any other scheme name.
    """
    atr_names = ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes',
                 'Skirt', 'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe',
                 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm',
                 'Bag', 'Scarf']
    clo_names = ['background', 'upper', 'bottom', 'upper-bottom']
    return {'ATR': atr_names, 'CLO': clo_names}[parsing_type]
def get_parse():
    """Parse command-line options for the feature-line extraction script."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--parsing_type', default='ATR',
                        choices=['ATR', 'CLO'], help='garment_parsing type')
    parser.add_argument('--input_path', default='', help='select model')
    parser.add_argument('--output_path', default='', help='polygons output')
    return parser.parse_args()
def parsing_curve(query_file, parsing_file, parsing_type, class_type, debug_path, name):
    """Refine the hand-annotated feature lines of one frame against its mask.

    Args:
        query_file: labelme-style json with 'shapes' (label + points).
        parsing_file: .npy file holding the [H, W] parsing label map.
        parsing_type: 'ATR' or 'CLO' label scheme.
        class_type: sequence name, used to look up the garment templates.
        debug_path: directory where per-garment debug masks are written.
        name: image file name, used to name the debug mask output.

    Returns:
        (new_query_sets, mask): refined polylines per feature-line label and
        the mask of the last garment processed.
    """
    query_sets = {}
    with open(query_file) as reader:
        fl_infos = json.load(reader)
        shapes = fl_infos['shapes']
        for fl in shapes:
            query_sets[fl['label']] = np.asarray(fl['points']).astype(np.float32)
    # Map each garment template name to the parsing region it occupies.
    garment_table = dict(
        short_sleeve_upper='upper',
        skirt='bottom',
        dress='upper_bottom',
        long_sleeve_upper='upper',
        long_pants='bottom',
        short_pants='bottom',
    )
    masks = np.load(parsing_file, allow_pickle=True)  # [H, W]
    poly_mask = PolyMask(masks)
    new_query_sets = {}
    for garment_key in TEMPLATE_GARMENT[class_type]:
        garment_class = get_upper_bttom_type(parsing_type, garment_table[garment_key])
        # Feature-line names belonging to this garment template.
        # (Was hard-coded to the string 'bottom_curve' -- a debug leftover
        # that iterated the string character-by-character, so no feature
        # line ever matched.)
        fl_names = FL_EXTRACT[garment_key]
        fl_query_sets = {}
        for fl_name in fl_names:
            if fl_name in query_sets.keys():
                fl_query_sets[fl_name] = query_sets[fl_name]
        new_fl_query_sets, mask = poly_mask.query(fl_query_sets, garment_class, garment_key)
        new_query_sets.update(new_fl_query_sets)
        cv2.imwrite(osp.join(debug_path, 'mask_{}_'.format(garment_key) + name), mask)
    return new_query_sets, mask
def main(args):
    """Run feature-line refinement for every annotated frame.

    Reads json annotations, parsing masks and images from ``args.input_path``,
    writes refined jsons to ``args.output_path`` and debug overlays to
    ``./debug/<class_type>/polymask``.
    """
    # ATR, CLO
    parsing_type = args.parsing_type
    parsing_dir = osp.join(args.input_path, 'parsing_SCH_{}'.format(parsing_type))
    img_dir = osp.join(args.input_path, 'imgs/')
    json_files = sorted(glob.glob(osp.join(args.input_path, 'json_hand_label_no_bottom_curve/*.json')))
    img_files = sorted(glob.glob(osp.join(img_dir, '*.jpg')))
    img_files += sorted(glob.glob(osp.join(img_dir, '*.png')))
    # Frame ids (e.g. '000342') of the annotated frames.
    json_key = [json_file.split('/')[-1][:-5] for json_file in json_files]
    parsing_files = sorted(glob.glob(osp.join(parsing_dir, '*.npy')))
    # Only annotated frames are processed: keep the parsing/img files whose
    # frame id has a matching json annotation.
    filter_parsing_files = list(filter(lambda x: x.split('/')[-1].split('_')[-1][:-4] in json_key, parsing_files))
    filter_img_files = list(filter(lambda x: x.split('/')[-1][:-4] in json_key, img_files))
    if args.input_path[-1] == '/':
        input_path = args.input_path[:-1]
    else:
        input_path = args.input_path
    class_type = input_path.split('/')[-1]
    debug_path = osp.join('./debug/{}/polymask'.format(class_type))
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(debug_path, exist_ok=True)
    for idx, (parsing_file, json_file, filter_img_file) in enumerate(zip(filter_parsing_files, json_files, filter_img_files)):
        # (Removed an unconditional pdb.set_trace() that halted every run.)
        print('processing: {}'.format(filter_img_file))
        img = cv2.imread(filter_img_file)
        name = filter_img_file.split('/')[-1]
        new_query_sets, mask = parsing_curve(json_file, parsing_file, args.parsing_type, class_type, debug_path, name)
        # Rewrite the annotation json with the refined feature-line points.
        with open(json_file) as reader:
            fl_infos = json.load(reader)
            shapes = fl_infos['shapes']
            for fl in shapes:
                fl['points'] = new_query_sets[fl['label']].tolist()
        json_name = json_file.split('/')[-1]
        new_json_file = os.path.join(output_path, json_name)
        with open(new_json_file, 'w') as writer:
            json.dump(fl_infos, writer)
        # Debug overlay: draw each refined polyline's points and segments.
        for key in new_query_sets.keys():
            color = FL_COLOR[key]
            # Builtin `int` instead of `np.int`: the alias was removed in
            # NumPy 1.24; also reuse the converted array for both loops.
            pt_list = new_query_sets[key].astype(int)
            for pt in pt_list:
                img = cv2.circle(img, (pt[0], pt[1]), 2, color, 2)
            for pt_idx in range(pt_list.shape[0] - 1):
                img = cv2.line(img, (pt_list[pt_idx][0], pt_list[pt_idx][1]), (pt_list[pt_idx + 1][0], pt_list[pt_idx + 1][1]), color, 2)
        cv2.imwrite(osp.join(debug_path, name), img)
if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the pipeline.
    main(get_parse())