Spaces:
Running
Running
from functools import lru_cache | |
from typing import List | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from hbutils.color import rnd_colors | |
@lru_cache()
def _get_complete_classes():
    """
    Return the complete collection of class labels known to the NudeNet model.

    ``open_model_session()`` is unpacked as ``(session, classes)`` and only the
    classes are kept. The result is cached so the model session is opened at
    most once per process instead of on every call; treat the returned object
    as read-only because every caller shares it.
    """
    # Imported lazily, presumably so nudenet is only loaded when actually needed.
    from nudenet import open_model_session
    _, classes = open_model_session()
    return classes
@lru_cache()
def _get_color_map():
    """
    Map every model class label to a one-element tuple holding its box color.

    Colors come from ``rnd_colors`` called with a fixed ``rnd=0`` argument,
    presumably a seed so each class keeps the same color across runs
    (TODO confirm). The mapping is built once and cached; treat the returned
    dict as read-only.

    :return: ``{class_label: (color_str,)}`` where ``color_str`` is ``str()``
        of the corresponding color object produced by ``rnd_colors``.
    """
    # Fetch the classes once and reuse them for both the count and the keys.
    all_classes = _get_complete_classes()
    colors = rnd_colors(len(all_classes), rnd=0)
    return {cls_: (str(color),) for color, cls_ in zip(colors, all_classes)}
# Short display names for selected detector class labels. plot_detection looks
# labels up here with .get(label, label), so unlisted labels are shown as-is.
CLS_MAP = dict(
    EXPOSED_BREAST_F='nipple',
    EXPOSED_GENITALIA_F='pussy',
    EXPOSED_GENITALIA_M='penis',
    EXPOSED_ANUS='anus',
)
def plot_detection(pil_img: Image.Image, detection: List):
    """
    Draw an image with its detection boxes and captions using pyplot's current axes.

    :param pil_img: Image to display; passed directly to ``plt.imshow``.
    :param detection: Iterable of detection records. Each record is read as a
        mapping with:

        - ``'box'``: ``(xmin, ymin, xmax, ymax)`` in the axes' data coordinates
          (image pixels under ``imshow``'s default extent, assumed);
        - ``'score'``: confidence, rendered as ``score * 100`` with a ``%`` sign,
          so presumably in ``[0, 1]``;
        - ``'label'``: class label, used to pick the box color and, through
          ``CLS_MAP``, the caption text.
    :raises KeyError: if a record's label has no entry in the color map built
        from the model's classes.
    """
    plt.imshow(pil_img)
    ax = plt.gca()
    color_map = _get_color_map()  # resolved once, not per detection

    for item in detection:
        score = item['score']
        xmin, ymin, xmax, ymax = item['box']
        label = item['label']
        box_color, = color_map[label]  # map values are one-element tuples

        ax.add_patch(plt.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            fill=False, color=box_color, linewidth=3,
        ))
        caption = f'{CLS_MAP.get(label, label)}: {score * 100:.2f}%'
        ax.text(xmin, ymin, caption, fontsize=8, bbox=dict(facecolor=box_color, alpha=0.5))

    # Called last: tight_layout only adjusts axes that already exist, so calling
    # it before imshow has no effect on a fresh figure.
    plt.tight_layout()