# Copyright (c) Tencent Inc. All rights reserved.
import os
import os.path as osp
import argparse

import cv2
import torch
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.runner import Runner
from mmengine.runner.amp import autocast
from mmengine.utils import ProgressBar
from mmyolo.registry import RUNNERS

import supervision as sv

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()


def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-World Demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('image',
                        help='image path, an image file or a directory.')
    parser.add_argument(
        'text',
        help='text prompts: either categories separated by commas or a txt '
        'file with one prompt per line.')
    parser.add_argument('--topk',
                        default=100,
                        type=int,
                        help='keep the top-k predictions.')
    parser.add_argument('--threshold',
                        default=0.0,
                        type=float,
                        help='confidence score threshold for predictions.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device used for inference.')
    parser.add_argument('--show',
                        action='store_true',
                        help='show the detection results.')
    parser.add_argument(
        '--annotation',
        action='store_true',
        help='save the detection results in YOLO text format.')
    parser.add_argument('--amp',
                        action='store_true',
                        help='use mixed precision for inference.')
    parser.add_argument('--output-dir',
                        default='demo_outputs',
                        help='the directory to save outputs.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; key-value pairs '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    return args


def inference_detector(runner,
                       image_path,
                       texts,
                       max_dets,
                       score_thr,
                       output_dir,
                       use_amp=False,
                       show=False,
                       annotation=False):
    # build a single-image batch through the test pipeline
    data_info = dict(img_id=0, img_path=image_path, texts=texts)
    data_info = runner.pipeline(data_info)
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])

    with autocast(enabled=use_amp), torch.no_grad():
        output = runner.model.test_step(data_batch)[0]
        pred_instances = output.pred_instances
        # keep predictions above the score threshold
        pred_instances = pred_instances[
            pred_instances.scores.float() > score_thr]
    # keep at most `max_dets` predictions with the highest scores
    if len(pred_instances.scores) > max_dets:
        indices = pred_instances.scores.float().topk(max_dets)[1]
        pred_instances = pred_instances[indices]

    pred_instances = pred_instances.cpu().numpy()

    if 'masks' in pred_instances:
        masks = pred_instances['masks']
    else:
        masks = None

    detections = sv.Detections(xyxy=pred_instances['bboxes'],
                               class_id=pred_instances['labels'],
                               confidence=pred_instances['scores'],
                               mask=masks)

    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]

    # annotate and save the image
    image = cv2.imread(image_path)
    anno_image = image.copy()
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    if masks is not None:
        image = MASK_ANNOTATOR.annotate(image, detections)
    cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image)

    if annotation:
        images_dict = {}
        annotations_dict = {}

        images_dict[osp.basename(image_path)] = anno_image
        annotations_dict[osp.basename(image_path)] = detections

        # os.makedirs returns None, so keep the directory path separately
        ANNOTATIONS_DIRECTORY = './annotations'
        os.makedirs(ANNOTATIONS_DIRECTORY, exist_ok=True)

        MIN_IMAGE_AREA_PERCENTAGE = 0.002
        MAX_IMAGE_AREA_PERCENTAGE = 0.80
        APPROXIMATION_PERCENTAGE = 0.75

        sv.DetectionDataset(
            classes=texts,
            images=images_dict,
            annotations=annotations_dict).as_yolo(
                annotations_directory_path=ANNOTATIONS_DIRECTORY,
                min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE,
                max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE,
                approximation_percentage=APPROXIMATION_PERCENTAGE)

    if show:
        cv2.imshow('Image', image)
        k = cv2.waitKey(0)
        if k == 27:
            # wait for the ESC key to exit
            cv2.destroyAllWindows()


if __name__ == '__main__':
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])
    cfg.load_from = args.checkpoint

    # build the runner from config
    if 'runner_type' not in cfg:
        runner = Runner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)

    # load text prompts (a blank prompt is appended at the end)
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
    else:
        texts = [[t.strip()] for t in args.text.split(',')] + [[' ']]

    output_dir = args.output_dir
    if not osp.exists(output_dir):
        os.mkdir(output_dir)

    runner.call_hook('before_run')
    runner.load_or_resume()
    pipeline = cfg.test_dataloader.dataset.pipeline
    runner.pipeline = Compose(pipeline)
    runner.model.eval()

    # collect images: a single file or all .png/.jpg files in a directory
    if not osp.isfile(args.image):
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        images = [args.image]

    progress_bar = ProgressBar(len(images))
    for image_path in images:
        inference_detector(runner,
                           image_path,
                           texts,
                           args.topk,
                           args.threshold,
                           output_dir=output_dir,
                           use_amp=args.amp,
                           show=args.show,
                           annotation=args.annotation)
        progress_bar.update()
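
# Example invocation, as a sketch: the script name, config path, checkpoint
# path, and image directory below are illustrative placeholders rather than
# files shipped with this snippet; only the flags come from the argument
# parser above.
#
#   python image_demo.py path/to/config.py path/to/checkpoint.pth \
#       path/to/images 'person,dog,backpack' \
#       --topk 100 --threshold 0.1 --output-dir demo_outputs --show
#
# Passing a .txt file as the `text` argument reads one prompt per line
# instead of splitting a comma-separated list.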