| # #!/usr/bin/env python | |
| # # -*- encoding: utf-8 -*- | |
| # """ | |
| # @Author : Peike Li | |
| # @Contact : peike.li@yahoo.com | |
| # @File : simple_extractor.py | |
| # @Time : 8/30/19 8:59 PM | |
| # @Desc : Simple Extractor | |
| # @License : This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # """ | |
| # import os | |
| # import torch | |
| # import argparse | |
| # import numpy as np | |
| # from PIL import Image | |
| # from tqdm import tqdm | |
| # from torch.utils.data import DataLoader | |
| # import torchvision.transforms as transforms | |
| # import os | |
| # import sys | |
| # _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) # .../DEMO/preprocess | |
| # if _THIS_DIR not in sys.path: | |
| # sys.path.insert(0, _THIS_DIR) | |
| # import networks | |
| # from utils.transforms import transform_logits | |
| # from datasets.simple_extractor_dataset import SimpleFolderDataset | |
| # dataset_settings = { | |
| # 'lip': { | |
| # 'input_size': [473, 473], | |
| # 'num_classes': 20, | |
| # 'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', | |
| # 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', | |
| # 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe'] | |
| # }, | |
| # 'atr': { | |
| # 'input_size': [512, 512], | |
| # 'num_classes': 18, | |
| # 'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt', | |
| # 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf'] | |
| # }, | |
| # 'pascal': { | |
| # 'input_size': [512, 512], | |
| # 'num_classes': 7, | |
| # 'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'], | |
| # } | |
| # } | |
| # def get_arguments(): | |
| # """Parse all the arguments provided from the CLI. | |
| # Returns: | |
| # A list of parsed arguments. | |
| # """ | |
| # parser = argparse.ArgumentParser(description="Self Correction for Human Parsing") | |
| # parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal']) | |
| # parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.") | |
| # parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.") | |
| # parser.add_argument("--category", type=str, default='Upper-clothes', help="category name (optional).") | |
| # parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.") | |
| # parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.") | |
| # parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.") | |
| # return parser.parse_args() | |
| # def get_palette(num_cls): | |
| # n = 18 | |
| # palette = [0] * (n * 3) | |
| # j = num_cls | |
| # lab = num_cls | |
| # palette[j * 3 + 0] = 0 | |
| # palette[j * 3 + 1] = 0 | |
| # palette[j * 3 + 2] = 0 | |
| # i = 0 | |
| # while lab: | |
| # palette[j * 3 + 0] = 255 | |
| # palette[j * 3 + 1] = 255 | |
| # palette[j * 3 + 2] = 255 | |
| # i += 1 | |
| # lab >>= 3 | |
| # return palette | |
| # def get_palette2(num_cls): | |
| # """ Returns the color map for visualizing the segmentation mask. | |
| # Args: | |
| # num_cls: Number of classes | |
| # Returns: | |
| # The color map | |
| # """ | |
| # n = 18 | |
| # palette = [0] * (n * 3) | |
| # for j in range(5, 7): | |
| # lab = j | |
| # palette[j * 3 + 0] = 0 | |
| # palette[j * 3 + 1] = 0 | |
| # palette[j * 3 + 2] = 0 | |
| # i = 0 | |
| # while lab: | |
| # palette[j * 3 + 0] = 255 | |
| # palette[j * 3 + 1] = 255 | |
| # palette[j * 3 + 2] = 255 | |
| # i += 1 | |
| # lab >>= 3 | |
| # return palette | |
| # def run( | |
| # *, | |
| # category: str, | |
| # input_path: str = "", | |
| # input_dir: str = "", | |
| # dataset: str = "atr", | |
| # model_restore: str = "", | |
| # gpu: str = "0", | |
| # logits: bool = False, | |
| # ): | |
| # """ | |
| # - input_path (단일 파일) 또는 input_dir(폴더) 중 하나를 받아 parsing 결과를 메모리로 반환. | |
| # - 파일 저장 없음. | |
| # Returns: | |
| # { | |
| # "images": List[PIL.Image], # parsing mask (palette 적용됨) | |
| # "logits": Optional[List[np.ndarray]], | |
| # "names": List[str], # 파일명들 | |
| # } | |
| # """ | |
| # # single GPU만 허용 | |
| # gpus = [int(i) for i in gpu.split(',')] | |
| # assert len(gpus) == 1 | |
| # if gpu != 'None': | |
| # os.environ["CUDA_VISIBLE_DEVICES"] = gpu | |
| # if not model_restore: | |
| # print("[simple_extractor] model_restore not provided → skip extractor.") | |
| # return {"images": [], "logits": [] if logits else None, "names": []} | |
| # # 입력 검증: 둘 중 하나는 있어야 함 | |
| # if bool(input_path) == bool(input_dir): | |
| # raise ValueError("Provide exactly one of input_path or input_dir.") | |
| # # 파일이면 존재 확인 | |
| # if input_path: | |
| # if not os.path.isfile(input_path): | |
| # raise FileNotFoundError(f"input_path not found or not a file: {input_path}") | |
| # # 폴더면 존재 확인 | |
| # if input_dir: | |
| # if not os.path.isdir(input_dir): | |
| # raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}") | |
| # num_classes = dataset_settings[dataset]['num_classes'] | |
| # input_size = dataset_settings[dataset]['input_size'] | |
| # label = dataset_settings[dataset]['label'] | |
| # print(f"Evaluating total class number {num_classes} with {label}") | |
| # model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None) | |
| # state_dict = torch.load(model_restore)['state_dict'] | |
| # from collections import OrderedDict | |
| # new_state_dict = OrderedDict() | |
| # for k, v in state_dict.items(): | |
| # name = k[7:] # remove `module.` | |
| # new_state_dict[name] = v | |
| # model.load_state_dict(new_state_dict) | |
| # model.cuda() | |
| # model.eval() | |
| # transform = transforms.Compose([ | |
| # transforms.ToTensor(), | |
| # transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]) | |
| # ]) | |
| # # ---- 파일 리스트 만들기 (단일 파일/폴더 모두 대응) ---- | |
| # if input_path: | |
| # # root는 파일의 부모 디렉터리, file_list는 파일명 1개 | |
| # root = os.path.dirname(input_path) | |
| # file_list = [os.path.basename(input_path)] | |
| # else: | |
| # root = input_dir | |
| # file_list = sorted([ | |
| # f for f in os.listdir(root) | |
| # if f.lower().endswith(('.png', '.jpg', '.jpeg')) | |
| # ]) | |
| # dataset_obj = SimpleFolderDataset( | |
| # root=root, | |
| # input_size=input_size, | |
| # transform=transform, | |
| # file_list=file_list | |
| # ) | |
| # dataloader = DataLoader(dataset_obj) | |
| # palette = get_palette(4) | |
| # results_img = [] | |
| # results_logits = [] if logits else None | |
| # names = [] | |
| # with torch.no_grad(): | |
| # for batch in tqdm(dataloader): | |
| # image, meta = batch | |
| # img_name = meta['name'][0] | |
| # names.append(img_name) | |
| # c = meta['center'].numpy()[0] | |
| # s = meta['scale'].numpy()[0] | |
| # w = meta['width'].numpy()[0] | |
| # h = meta['height'].numpy()[0] | |
| # output = model(image.cuda()) | |
| # upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True) | |
| # upsample_output = upsample(output[0][-1][0].unsqueeze(0)) | |
| # upsample_output = upsample_output.squeeze() | |
| # upsample_output = upsample_output.permute(1, 2, 0) | |
| # logits_result = transform_logits( | |
| # upsample_output.data.cpu().numpy(), | |
| # c, s, w, h, | |
| # input_size=input_size | |
| # ) | |
| # parsing_result = np.argmax(logits_result, axis=2) | |
| # out_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8)) | |
| # out_img.putpalette(palette) | |
| # results_img.append(out_img) | |
| # if logits: | |
| # results_logits.append(logits_result) | |
| # return {"images": results_img, "logits": results_logits, "names": names} | |
| # def main(): | |
| # # ✅ CLI 호환 유지 | |
| # args = get_arguments() | |
| # run( | |
| # category=args.category, | |
| # input_dir=args.input_dir, | |
| # output_dir=args.output_dir, | |
| # ) | |
| # if __name__ == '__main__': | |
| # main() | |
| #!/usr/bin/env python | |
| # -*- encoding: utf-8 -*- | |
| """ | |
| @Author : Peike Li | |
| @Contact : peike.li@yahoo.com | |
| @File : simple_extractor.py | |
| @Desc : Simple Extractor (category-aware palette selection) | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
| import torchvision.transforms as transforms | |
| _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| if _THIS_DIR not in sys.path: | |
| sys.path.insert(0, _THIS_DIR) | |
| import networks | |
| from utils.transforms import transform_logits | |
| from datasets.simple_extractor_dataset import SimpleFolderDataset | |
# Per-dataset configuration used by ``run``:
#   input_size  -- [height, width] passed to the dataset and to the upsampler
#   num_classes -- number of classes the parsing network is built with
#   label       -- class names; presumably ordered by predicted class id
#                  (the palette helpers index ATR classes this way) -- confirm
#                  against the checkpoint if you add a dataset.
dataset_settings = {
    'lip': dict(
        input_size=[473, 473],
        num_classes=20,
        label=[
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses',
            'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants',
            'Jumpsuits', 'Scarf', 'Skirt', 'Face',
            'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg',
            'Left-shoe', 'Right-shoe',
        ],
    ),
    'atr': dict(
        input_size=[512, 512],
        num_classes=18,
        label=[
            'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes',
            'Skirt', 'Pants', 'Dress', 'Belt',
            'Left-shoe', 'Right-shoe', 'Face',
            'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm',
            'Bag', 'Scarf',
        ],
    ),
    'pascal': dict(
        input_size=[512, 512],
        num_classes=7,
        label=[
            'Background', 'Head', 'Torso',
            'Upper Arms', 'Lower Arms',
            'Upper Legs', 'Lower Legs',
        ],
    ),
}
def get_arguments():
    """Parse the simple-extractor command line.

    Returns:
        argparse.Namespace with ``dataset``, ``model_restore``, ``gpu``,
        ``category``, ``input_dir``, ``output_dir`` and ``logits``.
    """
    cli = argparse.ArgumentParser(description="Self Correction for Human Parsing")
    # (flag, add_argument keyword arguments) in the order they are registered,
    # which is also the order shown by --help.
    option_specs = (
        ("--dataset", dict(type=str, default='atr', choices=['lip', 'atr', 'pascal'])),
        ("--model-restore", dict(type=str, default='', help="restore pretrained model parameters.")),
        ("--gpu", dict(type=str, default='0', help="choose gpu device.")),
        ("--category", dict(type=str, default='Upper-cloth', help="category name.")),
        ("--input-dir", dict(type=str, default='', help="path of input image folder.")),
        ("--output-dir", dict(type=str, default='', help="(unused, kept for CLI compatibility)")),
        ("--logits", dict(action='store_true', default=False)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def get_palette(num_cls, *, num_entries=18):
    """Return a flat RGB palette that paints exactly one class index white.

    Despite the parameter name, ``num_cls`` is a class *index*, not a class
    count. Palette entry ``num_cls`` becomes (255, 255, 255) and every other
    entry stays (0, 0, 0). Index 0 is never highlighted, which matches the
    previous behaviour, so ``get_palette(0)`` is all black.

    Args:
        num_cls: Index of the class to paint white.
        num_entries: Number of RGB entries in the palette. The default of 18
            is the historical size.

    Returns:
        A list of ``num_entries * 3`` ints suitable for
        ``PIL.Image.putpalette``.

    Raises:
        IndexError: If ``num_cls`` is outside ``range(num_entries)``.
            Previously a negative index looped forever, because ``x >>= 3``
            on a negative int converges to -1 and never reaches 0.
    """
    if not 0 <= num_cls < num_entries:
        raise IndexError(
            f"class index {num_cls} is outside the palette range [0, {num_entries})"
        )
    palette = [0] * (num_entries * 3)
    if num_cls:
        start = num_cls * 3
        palette[start:start + 3] = [255, 255, 255]
    return palette
def get_palette2(num_cls):
    """Return an 18-entry RGB palette with class indices 5 and 6 painted white.

    With the ATR label order in ``dataset_settings`` these are 'Skirt' and
    'Pants', i.e. a lower-body mask. ``num_cls`` is accepted only to mirror
    ``get_palette``'s signature; it has no effect on the result.

    Returns:
        A flat list of 54 ints (18 RGB triples) for ``PIL.Image.putpalette``.
    """
    white_entries = {5, 6}
    # Each palette entry occupies three consecutive ints (R, G, B).
    return [
        255 if flat_idx // 3 in white_entries else 0
        for flat_idx in range(18 * 3)
    ]
def _select_palette_by_category(category: str):
    """Pick the binary-mask palette for a garment category.

    The class indices used here follow the ATR label order in
    ``dataset_settings['atr']``: 4 = 'Upper-clothes', 5/6 = 'Skirt'/'Pants',
    and 7 = 'Dress'.

    Accepted categories:
        "Upper-cloth" or "Upper-clothes" -> upper-clothes mask
        "Bottom"                         -> skirt + pants mask
        "Dress"                          -> dress mask

    Any other value falls back to the dress mask, as before, but a notice is
    now printed instead of silently switching garments. "Upper-clothes" was
    the previous CLI default and is the ATR label name. It used to hit that
    silent fallback and produce a Dress mask.
    """
    if category in ("Upper-cloth", "Upper-clothes"):
        return get_palette(4)
    if category == "Bottom":
        # get_palette2 ignores its argument; it always marks indices 5 and 6.
        return get_palette2(4)
    if category != "Dress":
        print(f"[simple_extractor] unknown category {category!r} → using Dress palette.")
    return get_palette(7)
| def run( | |
| *, | |
| category: str, | |
| input_path: str = "", | |
| input_dir: str = "", | |
| dataset: str = "atr", | |
| model_restore: str = "", | |
| gpu: str = "0", | |
| logits: bool = False, | |
| ): | |
| """ | |
| Returns: | |
| { | |
| "images": List[PIL.Image], | |
| "logits": Optional[List[np.ndarray]], | |
| "names": List[str], | |
| } | |
| """ | |
| gpus = [int(i) for i in gpu.split(',')] | |
| assert len(gpus) == 1 | |
| if gpu != 'None': | |
| os.environ["CUDA_VISIBLE_DEVICES"] = gpu | |
| if not model_restore: | |
| print("[simple_extractor] model_restore not provided → skip extractor.") | |
| return {"images": [], "logits": [] if logits else None, "names": []} | |
| if bool(input_path) == bool(input_dir): | |
| raise ValueError("Provide exactly one of input_path or input_dir.") | |
| if input_path and not os.path.isfile(input_path): | |
| raise FileNotFoundError(input_path) | |
| if input_dir and not os.path.isdir(input_dir): | |
| raise NotADirectoryError(input_dir) | |
| num_classes = dataset_settings[dataset]['num_classes'] | |
| input_size = dataset_settings[dataset]['input_size'] | |
| model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None) | |
| state_dict = torch.load(model_restore)['state_dict'] | |
| from collections import OrderedDict | |
| new_state_dict = OrderedDict() | |
| for k, v in state_dict.items(): | |
| new_state_dict[k[7:]] = v | |
| model.load_state_dict(new_state_dict) | |
| model.cuda() | |
| model.eval() | |
| transform = transforms.Compose([ | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.406, 0.456, 0.485], | |
| std=[0.225, 0.224, 0.229]) | |
| ]) | |
| if input_path: | |
| root = os.path.dirname(input_path) | |
| file_list = [os.path.basename(input_path)] | |
| else: | |
| root = input_dir | |
| file_list = sorted([ | |
| f for f in os.listdir(root) | |
| if f.lower().endswith(('.png', '.jpg', '.jpeg')) | |
| ]) | |
| dataset_obj = SimpleFolderDataset( | |
| root=root, | |
| input_size=input_size, | |
| transform=transform, | |
| file_list=file_list | |
| ) | |
| dataloader = DataLoader(dataset_obj) | |
| # ✅ 핵심 수정: category 기반 palette 선택 | |
| palette = _select_palette_by_category(category) | |
| results_img = [] | |
| results_logits = [] if logits else None | |
| names = [] | |
| with torch.no_grad(): | |
| for batch in tqdm(dataloader): | |
| image, meta = batch | |
| img_name = meta['name'][0] | |
| names.append(img_name) | |
| c = meta['center'].numpy()[0] | |
| s = meta['scale'].numpy()[0] | |
| w = meta['width'].numpy()[0] | |
| h = meta['height'].numpy()[0] | |
| output = model(image.cuda()) | |
| upsample = torch.nn.Upsample( | |
| size=input_size, mode='bilinear', align_corners=True | |
| ) | |
| upsample_output = upsample(output[0][-1][0].unsqueeze(0)) | |
| upsample_output = upsample_output.squeeze().permute(1, 2, 0) | |
| logits_result = transform_logits( | |
| upsample_output.data.cpu().numpy(), | |
| c, s, w, h, | |
| input_size=input_size | |
| ) | |
| parsing_result = np.argmax(logits_result, axis=2) | |
| out_img = Image.fromarray(parsing_result.astype(np.uint8)) | |
| out_img.putpalette(palette) | |
| results_img.append(out_img) | |
| if logits: | |
| results_logits.append(logits_result) | |
| return { | |
| "images": results_img, | |
| "logits": results_logits, | |
| "names": names | |
| } | |
def main():
    """CLI entry point: parse the arguments and run the extractor on ``--input-dir``.

    The dictionary returned by ``run`` is not used here, and ``--output-dir``
    is parsed but ignored.
    """
    cli_args = get_arguments()
    run_options = dict(
        category=cli_args.category,
        input_dir=cli_args.input_dir,
        dataset=cli_args.dataset,
        model_restore=cli_args.model_restore,
        gpu=cli_args.gpu,
        logits=cli_args.logits,
    )
    run(**run_options)


if __name__ == '__main__':
    main()