auto_image_censor / visual.py
narugo1992
dev(narugo): move the process inside
89ce0c1
raw
history blame contribute delete
No virus
1.26 kB
from functools import lru_cache
from typing import List
import matplotlib.pyplot as plt
from PIL import Image
from hbutils.color import rnd_colors
@lru_cache()
def _get_complete_classes():
    """Return the class collection reported by ``nudenet.open_model_session``.

    ``open_model_session()`` is unpacked as a two-element result; only the
    second element (the classes) is returned and the first is discarded.
    The result is memoized by ``lru_cache``, so the session is opened by
    this function at most once per process.
    """
    # Imported at call time rather than at module import time.
    from nudenet import open_model_session

    session_result = open_model_session()
    _unused_session, class_names = session_result
    return class_names
@lru_cache()
def _get_color_map():
    """Build the class -> color lookup used when drawing detections.

    One color is generated per class returned by ``_get_complete_classes()``.
    The result is memoized by ``lru_cache``, so every caller shares the same
    dictionary object.

    Returns:
        A dict mapping each class to a 1-tuple ``(color_string,)``, where
        ``color_string`` is ``str()`` of the generated color. The 1-tuple
        shape is kept on purpose: ``plot_detection`` unpacks it with
        ``box_color, = ...``.
    """
    class_names = _get_complete_classes()
    # Reuse the list fetched above instead of calling the getter a second time.
    # NOTE(review): rnd=0 presumably fixes the random seed so the palette is
    # reproducible between runs -- confirm against the hbutils documentation.
    palette = rnd_colors(len(class_names), rnd=0)
    return {name: (str(color),) for color, name in zip(palette, class_names)}
# Maps detection label names to the display names used in plot annotations
# (see plot_detection). A label that has no entry here is shown unchanged,
# because the lookup there falls back to the label itself.
CLS_MAP = {
    'EXPOSED_BREAST_F': 'nipple',
    'EXPOSED_GENITALIA_F': 'pussy',
    'EXPOSED_GENITALIA_M': 'penis',
    'EXPOSED_ANUS': 'anus',
}
def plot_detection(pil_img: Image.Image, detection: List):
    """Draw an image with its detection boxes on the current pyplot axes.

    The image is shown with ``plt.imshow`` and, for every detection, an
    unfilled rectangle plus a text caption of the form
    ``"<display name>: <score as percent>%"`` is added to the axes returned
    by ``plt.gca()``. Nothing is returned; the figure is neither shown nor
    saved here.

    Args:
        pil_img: Image to display.
        detection: Iterable of mappings, each providing:
            ``'score'`` -- number that is multiplied by 100 for display
            (presumably a confidence in [0, 1] -- confirm with the caller),
            ``'box'`` -- four values unpacked as ``xmin, ymin, xmax, ymax``,
            ``'label'`` -- class name; it must be a key of the color map
            from ``_get_color_map()``, otherwise ``KeyError`` is raised.
    """
    # NOTE(review): tight_layout() runs before anything is drawn; kept in
    # this position to preserve behavior -- confirm this ordering is intended.
    plt.tight_layout()
    plt.imshow(pil_img)
    axes = plt.gca()
    palette = _get_color_map()

    for entry in detection:
        confidence = entry['score']
        left, top, right, bottom = entry['box']
        label = entry['label']
        # Color map values are 1-tuples; unpack the single color string.
        (outline_color,) = palette[label]

        box_width = right - left
        box_height = bottom - top
        outline = plt.Rectangle(
            (left, top), box_width, box_height,
            fill=False, color=outline_color, linewidth=3,
        )
        axes.add_patch(outline)

        # Prefer the short display name; fall back to the raw label.
        display_name = CLS_MAP.get(label, label)
        caption = f'{display_name}: {confidence * 100:.2f}%'
        caption_box = {'facecolor': outline_color, 'alpha': 0.5}
        axes.text(left, top, caption, fontsize=8, bbox=caption_box)