Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABCMeta | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
from scepter.modules.annotator.base_annotator import BaseAnnotator | |
from scepter.modules.annotator.registry import ANNOTATORS | |
from scepter.modules.utils.config import dict_to_yaml | |
class ColorAnnotator(BaseAnnotator, metaclass=ABCMeta):
    """Turn an image into a coarse color-block map.

    The image is shrunk by an integer ``ratio`` with bicubic interpolation
    and then blown back up to its original size with nearest-neighbour
    interpolation, which leaves large flat blocks of color.

    Config keys read from ``cfg``:
        RATIO: fixed downsampling ratio (default 64).
        RANDOM_CFG: optional mapping enabling a random ratio per call.
            PROBA: probability of drawing a random ratio (default 1.0);
                otherwise ``RATIO`` is used.
            CHOICE_RATIO: if present, the ratio is drawn from this sequence.
            MIN_RATIO / MAX_RATIO: otherwise the ratio is drawn with
                ``np.random.randint(MIN_RATIO, MAX_RATIO)`` (defaults 48, 96).
    """

    para_dict = {}

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        self.ratio = cfg.get('RATIO', 64)
        self.random_cfg = cfg.get('RANDOM_CFG', None)

    def _sample_ratio(self):
        """Return the downsampling ratio to use for one ``forward`` call."""
        if self.random_cfg is None:
            return self.ratio
        proba = self.random_cfg.get('PROBA', 1.0)
        if np.random.random() >= proba:
            # Random branch not taken this time: fall back to the fixed ratio.
            return self.ratio
        if 'CHOICE_RATIO' in self.random_cfg:
            return np.random.choice(self.random_cfg['CHOICE_RATIO'])
        min_ratio = self.random_cfg.get('MIN_RATIO', 48)
        max_ratio = self.random_cfg.get('MAX_RATIO', 96)
        # NOTE: np.random.randint excludes the upper bound, so MAX_RATIO
        # itself is never drawn.
        return np.random.randint(min_ratio, max_ratio)

    def forward(self, image):
        """Return the color-block version of ``image`` as a numpy array.

        Args:
            image: a PIL image, a ``torch.Tensor`` or a ``np.ndarray``. The
                first two axes are treated as (height, width).
                NOTE(review): for tensors this presumably means an HWC
                layout -- confirm against callers.

        Returns:
            np.ndarray with the same height and width as the input.

        Raises:
            TypeError: if ``image`` is not one of the supported types.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        elif isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        elif isinstance(image, np.ndarray):
            # Work on a copy so the caller's array is never touched.
            image = image.copy()
        else:
            # Raising a bare string is itself a TypeError in Python 3 and
            # hides the message; raise a real exception instead.
            raise TypeError(
                f'Unsupported datatype {type(image)}, only support '
                f'np.ndarray, torch.Tensor, Pillow Image.')

        h, w = image.shape[:2]
        ratio = self._sample_ratio()

        # Clamp to one pixel: cv2.resize rejects a zero-sized target, which
        # would otherwise happen for images smaller than ``ratio``.
        small_w = max(1, int(w // ratio))
        small_h = max(1, int(h // ratio))

        image = cv2.resize(image, (small_w, small_h),
                           interpolation=cv2.INTER_CUBIC)
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_NEAREST)
        # Sanity check kept from the original: the result must not be batched.
        assert len(image.shape) < 4
        return image

    @staticmethod
    def get_config_template():
        """Return the YAML config template for this annotator."""
        return dict_to_yaml('ANNOTATORS',
                            __class__.__name__,
                            ColorAnnotator.para_dict,
                            set_name=True)