# IDM-VTON / apply_net.py
# (IDM-VTON repository, commit 938e515 "update IDM-VTON Demo"; 13 kB)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import glob
import logging
import os
import sys
from typing import Any, ClassVar, Dict, List
import torch
from detectron2.config import CfgNode, get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from detectron2.utils.logger import setup_logger
from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import (
DensePoseOutputsTextureVisualizer,
DensePoseOutputsVertexVisualizer,
get_texture_atlases,
)
from densepose.vis.densepose_results import (
DensePoseResultsContourVisualizer,
DensePoseResultsFineSegmentationVisualizer,
DensePoseResultsUVisualizer,
DensePoseResultsVVisualizer,
)
from densepose.vis.densepose_results_textures import (
DensePoseResultsVisualizerWithTexture,
get_texture_atlas,
)
from densepose.vis.extractor import (
CompoundExtractor,
DensePoseOutputsExtractor,
DensePoseResultExtractor,
create_extractor,
)
# Description shown by the top-level argument parser (see create_argument_parser).
DOC = "Apply Net - a tool to print / visualize DensePose results\n"
# Logger name for this script; main() replaces `logger` with a configured one.
LOGGER_NAME = "apply_net"
logger = logging.getLogger(LOGGER_NAME)
# Sub-command name (an action class's COMMAND) -> action class, filled by @register_action.
_ACTION_REGISTRY: Dict[str, "Action"] = {}
class Action:
    """Base class for apply_net sub-commands; contributes options shared by all of them."""

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the repeatable ``-v/--verbosity`` counter option to *parser*."""
        verbosity_help = "Verbose mode. Multiple -v options increase the verbosity."
        parser.add_argument("-v", "--verbosity", action="count", help=verbosity_help)
def register_action(cls: type):
    """Class decorator that records *cls* in ``_ACTION_REGISTRY`` under ``cls.COMMAND``.

    The class itself is returned unchanged, so this is usable as ``@register_action``.
    """
    # Only the dict's contents are mutated, so no `global` declaration is required.
    _ACTION_REGISTRY[cls.COMMAND] = cls
    return cls
class InferenceAction(Action):
    """Base for actions that run a DensePose predictor on one in-memory image.

    Subclasses supply ``create_context``, ``execute_on_outputs`` and ``postexecute``;
    :meth:`execute` builds the model and drives those hooks.
    """

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the config/model positionals and the trailing ``--opts`` overrides."""
        super(InferenceAction, cls).add_arguments(parser)
        parser.add_argument("cfg", metavar="<config>", help="Config file")
        parser.add_argument("model", metavar="<model>", help="Model file")
        parser.add_argument(
            "--opts",
            help="Modify config options using the command-line 'KEY VALUE' pairs",
            default=[],
            # REMAINDER: everything after --opts is taken as KEY VALUE pairs.
            nargs=argparse.REMAINDER,
        )

    @classmethod
    def execute(cls: type, args: argparse.Namespace, human_img):
        """Run the predictor on *human_img* and pass the outputs to the subclass hooks.

        Args:
            args: parsed namespace (uses ``cfg``, ``model``, ``opts`` plus any
                subclass-specific options).
            human_img: image fed to the predictor. The predictor expects a BGR image;
                presumably an HxWx3 array — TODO confirm against callers.

        Returns:
            Whatever the subclass's ``execute_on_outputs`` returns (the rendered
            image for ``show``; ``None`` for ``dump``).

        NOTE(review): a new DefaultPredictor (and thus a model load) is created on
        every call; callers processing many images may want to cache it.
        """
        logger.info(f"Loading config from {args.cfg}")
        opts: List[str] = []
        cfg = cls.setup_config(args.cfg, args.model, args, opts)
        logger.info(f"Loading model from {args.model}")
        predictor = DefaultPredictor(cfg)
        context = cls.create_context(args, cfg)
        with torch.no_grad():
            outputs = predictor(human_img)["instances"]
            out_pose = cls.execute_on_outputs(context, {"image": human_img}, outputs)
        cls.postexecute(context)
        return out_pose

    @classmethod
    def setup_config(
        cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
    ):
        """Build a frozen config from DensePose defaults and the given overrides.

        Merge order: DensePose defaults, then *config_fpath*, then ``args.opts``,
        then *opts*. Finally ``MODEL.WEIGHTS`` is set to *model_fpath*.
        """
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        cfg.merge_from_list(args.opts)
        if opts:
            cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.freeze()
        return cfg

    @classmethod
    def _get_input_file_list(cls: type, input_spec: str):
        """Expand *input_spec* (directory, single file or glob pattern) into file paths.

        A directory yields the regular files directly inside it; an existing file
        yields ``[input_spec]``; anything else is treated as a glob pattern. Results
        are sorted so the processing order is deterministic.

        NOTE(review): not used by :meth:`execute` in this fork, which takes an
        in-memory image instead of an input path.
        """
        if os.path.isdir(input_spec):
            candidates = (os.path.join(input_spec, fname) for fname in os.listdir(input_spec))
            return sorted(path for path in candidates if os.path.isfile(path))
        if os.path.isfile(input_spec):
            return [input_spec]
        return sorted(glob.glob(input_spec))
@register_action
class DumpAction(InferenceAction):
    """
    Dump action that outputs results to a pickle file
    """

    COMMAND: ClassVar[str] = "dump"

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """Register the ``dump`` sub-command; its handler is :meth:`execute`."""
        parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the inherited options plus ``--output`` (dump file path)."""
        super(DumpAction, cls).add_arguments(parser)
        parser.add_argument(
            "--output",
            metavar="<dump_file>",
            default="results.pkl",
            help="File name to save dump to",
        )

    @classmethod
    def execute_on_outputs(
        cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
    ):
        """Collect scores, boxes and DensePose data from *outputs* into ``context["results"]``.

        Raises:
            TypeError: if ``outputs.pred_densepose`` is neither a chart nor an
                embedding predictor output.
        """
        # InferenceAction.execute passes only {"image": ...}, so "file_name" may be absent.
        image_fpath = entry.get("file_name")
        if image_fpath is not None:
            logger.info(f"Processing {image_fpath}")
        result = {"file_name": image_fpath}
        if outputs.has("scores"):
            result["scores"] = outputs.get("scores").cpu()
        if outputs.has("pred_boxes"):
            result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
        if outputs.has("pred_densepose"):
            if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
                extractor = DensePoseResultExtractor()
            elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
                extractor = DensePoseOutputsExtractor()
            else:
                # Previously this fell through to an UnboundLocalError on `extractor`.
                raise TypeError(
                    "Unsupported DensePose predictor output type: "
                    f"{type(outputs.pred_densepose).__name__}"
                )
            result["pred_densepose"] = extractor(outputs)[0]
        context["results"].append(result)

    @classmethod
    def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode):
        """Return a fresh context holding the result list and the output path."""
        return {"results": [], "out_fname": args.output}

    @classmethod
    def postexecute(cls: type, context: Dict[str, Any]):
        """Write ``context["results"]`` to ``context["out_fname"]`` with ``torch.save``."""
        out_fname = context["out_fname"]
        out_dir = os.path.dirname(out_fname)
        if out_dir:
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(out_dir, exist_ok=True)
        with open(out_fname, "wb") as hFile:
            torch.save(context["results"], hFile)
        logger.info(f"Output saved to {out_fname}")
@register_action
class ShowAction(InferenceAction):
    """
    Show action that visualizes selected entries on an image
    """

    COMMAND: ClassVar[str] = "show"
    # Names accepted in the <visualizations> argument -> visualizer class.
    VISUALIZERS: ClassVar[Dict[str, object]] = {
        "dp_contour": DensePoseResultsContourVisualizer,
        "dp_segm": DensePoseResultsFineSegmentationVisualizer,
        "dp_u": DensePoseResultsUVisualizer,
        "dp_v": DensePoseResultsVVisualizer,
        "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
        "dp_cse_texture": DensePoseOutputsTextureVisualizer,
        "dp_vertex": DensePoseOutputsVertexVisualizer,
        "bbox": ScoredBoundingBoxVisualizer,
    }

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """Register the ``show`` sub-command; its handler is :meth:`execute`."""
        parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the inherited options plus visualization, filtering and texture options."""
        super(ShowAction, cls).add_arguments(parser)
        parser.add_argument(
            "visualizations",
            metavar="<visualizations>",
            help="Comma separated list of visualizations, possible values: "
            "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
        )
        parser.add_argument(
            "--min_score",
            metavar="<score>",
            default=0.8,
            type=float,
            help="Minimum detection score to visualize",
        )
        parser.add_argument(
            "--nms_thresh", metavar="<threshold>", default=None, type=float, help="NMS threshold"
        )
        parser.add_argument(
            "--texture_atlas",
            metavar="<texture_atlas>",
            default=None,
            help="Texture atlas file (for IUV texture transfer)",
        )
        parser.add_argument(
            "--texture_atlases_map",
            metavar="<texture_atlases_map>",
            default=None,
            help="JSON string of a dict containing texture atlas files for each mesh",
        )
        # NOTE(review): kept for CLI compatibility; execute_on_outputs returns the
        # image instead of writing it, so this value is currently unused.
        parser.add_argument(
            "--output",
            metavar="<image_file>",
            default="outputres.png",
            help="File name to save output to",
        )

    @classmethod
    def setup_config(
        cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
    ):
        """Add score/NMS threshold overrides from *args* to *opts*, then build the config."""
        # Work on a copy so the caller's list is not mutated as a side effect.
        opts = list(opts)
        opts += ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(args.min_score)]
        if args.nms_thresh is not None:
            opts += ["MODEL.ROI_HEADS.NMS_THRESH_TEST", str(args.nms_thresh)]
        return super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)

    @classmethod
    def execute_on_outputs(
        cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
    ):
        """Render the configured visualizations for one image and return the result.

        Args:
            context: dict from :meth:`create_context` (uses ``"visualizer"`` and
                ``"extractor"``).
            entry: dict whose ``"image"`` is the input image; it is converted with
                ``COLOR_BGR2GRAY``, so it is expected to be a BGR image.
            outputs: predicted ``Instances`` for that image.

        Returns:
            The output of ``visualizer.visualize`` drawn over a 3-channel grayscale
            copy of the input.
        """
        import cv2
        import numpy as np

        visualizer = context["visualizer"]
        extractor = context["extractor"]
        # Visualizations are drawn on a grayscale copy of the input, with the single
        # gray channel tiled back to 3 channels.
        image = cv2.cvtColor(entry["image"], cv2.COLOR_BGR2GRAY)
        image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
        data = extractor(outputs)
        return visualizer.visualize(image, data)

    @classmethod
    def postexecute(cls: type, context: Dict[str, Any]):
        """Nothing to finalize: the rendered image is returned, not written to disk."""
        pass

    # Example (this fork): drive the action from Python with an in-memory BGR image:
    #   args = create_argument_parser().parse_args(
    #       ("show", "./configs/densepose_rcnn_R_50_FPN_s1x.yaml", "<model_final.pkl>",
    #        "dp_segm", "-v", "--opts", "MODEL.DEVICE", "cpu"))
    #   pose_img = args.func(args, bgr_image)

    @classmethod
    def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
        """Insert a zero-padded 4-digit *entry_idx* before the extension of *fname_base*.

        E.g. ``("out.png", 3)`` -> ``"out.0003.png"``. Not called in this file now.
        """
        base, ext = os.path.splitext(fname_base)
        return base + ".{0:04d}".format(entry_idx) + ext

    @classmethod
    def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
        """Build the compound visualizer/extractor for ``args.visualizations``.

        Raises:
            KeyError: if a comma-separated name is not a key of ``VISUALIZERS``.
        """
        # The atlases depend only on CLI args, so load them once, not once per spec.
        texture_atlas = get_texture_atlas(args.texture_atlas)
        texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
        visualizers = []
        extractors = []
        for vis_spec in args.visualizations.split(","):
            vis = cls.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
            )
            visualizers.append(vis)
            extractors.append(create_extractor(vis))
        return {
            "extractor": CompoundExtractor(extractors),
            "visualizer": CompoundVisualizer(visualizers),
            # "out_fname"/"entry_idx" are kept for compatibility; unused in this file now.
            "out_fname": args.output,
            "entry_idx": 0,
        }
def create_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with one sub-command per class in ``_ACTION_REGISTRY``.

    When no sub-command is given, ``args.func`` defaults to printing help to stdout.
    """
    parser = argparse.ArgumentParser(
        description=DOC,
        formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=120),
    )
    parser.set_defaults(func=lambda _: parser.print_help(sys.stdout))
    subparsers = parser.add_subparsers(title="Actions")
    # Only the action classes are needed; the command names live on each class.
    for action in _ACTION_REGISTRY.values():
        action.add_parser(subparsers)
    return parser
def main():
    """Parse command-line arguments, configure logging and dispatch to ``args.func``.

    NOTE(review): the ``show``/``dump`` handlers (``InferenceAction.execute``) take a
    second ``human_img`` argument that is not supplied here, so dispatching those
    sub-commands from this entry point raises ``TypeError``; only the default
    help-printing handler works as-is.
    """
    global logger
    cli_args = create_argument_parser().parse_args()
    logger = setup_logger(name=LOGGER_NAME)
    logger.setLevel(verbosity_to_level(getattr(cli_args, "verbosity", None)))
    cli_args.func(cli_args)
if __name__ == "__main__":
    main()
# Example (this fork): the CLI has no <input> path argument, and the show/dump
# handlers need an image, so call them from Python, e.g.
#   args = create_argument_parser().parse_args(
#       ("show", "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
#        "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl",
#        "dp_segm", "-v", "--opts", "MODEL.DEVICE", "cuda"))
#   pose_img = args.func(args, bgr_image)