# Page banner captured together with this source (not part of the program): "Spaces: Running on Zero".
import json
import random

# NOTE: original import order is kept on purpose (`spaces` stays ahead of torch).
import spaces
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageColor
from torchvision.io import ImageReadMode, read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
# Preprocessing transforms that ship with the default Mask R-CNN (ResNet-50 FPN) weights.
# They are built once at import time and applied to every input image before inference.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
def fix_category_id(cat_ids: list) -> list:
    """Translate category IDs of the reduced label space back to the original IDs.

    The reduced label space is the 27 original categories (0-26) with the
    categories {2, 12, 16, 19, 20} removed and the rest renumbered
    consecutively. Both the incoming and the returned IDs are 1-based.

    Args:
        cat_ids: 1-based category IDs in the reduced label space.

    Returns:
        A list with the matching 1-based IDs in the original label space, in
        the same order as `cat_ids`.

    Raises:
        KeyError: if an ID does not exist in the reduced label space.
    """
    # Define the excluded category ids and the remaining ones.
    excluded_indices = {2, 12, 16, 19, 20}
    # Sort explicitly: the mapping below depends on ascending order, and set
    # iteration order is not something the language guarantees.
    remaining_categories = sorted(set(range(27)) - excluded_indices)
    # Map new (0-based) IDs to old/original (0-based) IDs.
    new_id_to_org_id = dict(enumerate(remaining_categories))
    return [new_id_to_org_id[i - 1] + 1 for i in cat_ids]
def process_categories() -> tuple:
    """
    Load and process category information from `categories.json`.

    The JSON file is read from the current working directory and is expected to
    hold a list of objects that each provide an "id" and a "name" key.

    Returns:
        tuple: A tuple containing two dictionaries:
            - `category_id_to_name`: a dictionary mapping category IDs to their names.
            - `category_id_to_color`: a dictionary mapping category IDs to a randomly
              (but reproducibly) sampled RGB color.
    """
    # Load raw categories from JSON file
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    # Map category IDs to names
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # Use a private, seeded generator so the color assignment is reproducible
    # without reseeding the process-wide `random` state on every call.
    rng = random.Random(42)

    # Get a list of all the color names in the PIL colormap
    color_names = list(ImageColor.colormap.keys())

    # Sample 46 unique colors from the list of color names
    sampled_colors = rng.sample(color_names, 46)

    # Convert the color names to RGB values
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]

    # Map category IDs to colors.
    # NOTE: `zip` stops at the shorter input, so at most 46 categories receive a color.
    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }

    return category_id_to_name, category_id_to_color
def draw_predictions(
    boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image based on the provided boxes, labels, scores, and masks. Only predictions
    with scores above the `score_threshold` will be included, and masks with probabilities exceeding the
    `proba_threshold` will be displayed.

    All images are written to the current working directory under fixed file names
    (`img_masks.png`, `img_bbox.png`, `mask-<i>.png`), so a later call overwrites the files of an earlier one.

    Args:
        - boxes: numpy.ndarray - an array of bounding box coordinates.
        - labels: numpy.ndarray - an array of integers representing the predicted class for each bounding box.
        - scores: numpy.ndarray - an array of confidence scores for each bounding box.
        - masks: numpy.ndarray - an array of per-pixel mask probabilities for each bounding box; it is binarized
            here with `proba_threshold`. Axis 1 is squeezed away before drawing (assumed to be a singleton
            channel axis - TODO confirm against the model output).
        - img: torch.Tensor - the input image as an image tensor (the torchvision drawing helpers used here
            operate on tensors, not on PIL images).
        - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
        - score_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
        - proba_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.

    Returns:
        - A list of strings, each representing the path to an image file containing the input image with a different
        set of predictions drawn (masks, bounding boxes, masks with bounding box labels and scores).
    """
    imgs_list = []

    # Map label IDs to names and colors
    label_id_to_name, label_id_to_color = process_categories()

    # Filter out predictions using thresholds. The selection is computed once and applied to
    # every per-detection array so that labels, masks, boxes and scores stay aligned.
    keep = scores > score_threshold
    labels_id = labels[keep].tolist()
    if model_name == "facere+":
        labels_id = fix_category_id(labels_id)
    # models output is in range: [1,class_id+1], hence re-map to: [0,class_id]
    labels = [label_id_to_name[int(i) - 1] for i in labels_id]
    masks = (masks[keep] > proba_threshold).astype(np.uint8)
    boxes = boxes[keep]
    # Previously the unfiltered scores were zipped with the filtered masks/labels below,
    # which pairs a mask with the wrong score whenever the kept detections are not a prefix.
    scores = scores[keep]

    # Draw masks to input image and save
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=[label_id_to_color[int(i) - 1] for i in labels_id],
    )
    img_masks = F.to_pil_image(img_masks)
    img_masks.save("img_masks.png")
    imgs_list.append("img_masks.png")

    # Draw bboxes to input image and save
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    img_bbox = F.to_pil_image(img_bbox)
    img_bbox.save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save masks with their bbox labels & bbox scores
    for col, (mask, label, score) in enumerate(zip(masks, labels, scores)):
        mask = Image.fromarray(mask.squeeze())
        plt.imshow(mask)
        plt.axis("off")
        plt.title(f"{label}: {score:.2f}", fontsize=9)
        plt.savefig(f"mask-{col}.png")
        plt.close()
        imgs_list.append(f"mask-{col}.png")

    return imgs_list
# ONNX Runtime sessions keyed by model file name, so each model is loaded once per process
# instead of on every request.
_ORT_SESSIONS = {}


def _get_ort_session(model_name):
    """Return the ONNX Runtime session for `model_name`, creating and caching it on first use."""
    filename = "facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx"
    session = _ORT_SESSIONS.get(filename)
    if session is None:
        # Download model
        path_onnx = hf_hub_download(
            repo_id="rizavelioglu/fashionfail",
            filename=filename,
        )
        # Create an inference session.
        session = onnxruntime.InferenceSession(
            path_onnx, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        _ORT_SESSIONS[filename] = session
    return session


def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Load the ONNX model and run inference with the provided input `image`. Visualize the predictions and save them in a
    figure, which will be shown in the Gradio app.

    Args:
        - image: str - path to the input image file.
        - model_name: str - "facere+" selects `facere_plus.onnx`; any other value selects `facere_base.onnx`.
        - mask_threshold: float - pixel-wise probability threshold forwarded to `draw_predictions`.
        - bbox_threshold: float - detection score threshold forwarded to `draw_predictions`.

    Returns:
        - The list of image file paths produced by `draw_predictions`.
    """
    # Load image. Force 3-channel RGB so grayscale or RGBA uploads do not break the
    # drawing helpers, which only accept RGB image tensors.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply original transformation to the image.
    img_transformed = transforms(img)

    ort_session = _get_ort_session(model_name)

    # compute ONNX Runtime output prediction
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    ort_outs = ort_session.run(None, ort_inputs)

    boxes, labels, scores, masks = ort_outs
    imgs_list = draw_predictions(
        boxes,
        labels,
        scores,
        masks,
        img,
        model_name,
        score_threshold=bbox_threshold,
        proba_threshold=mask_threshold,
    )

    return imgs_list
# Static texts shown by the Gradio interface: page title, the description above the
# inputs, and the article (usage note and citation) rendered below them.
title = "Facere - Demo"
description = r"""This is the demo of the paper <a href="https://arxiv.org/abs/2404.08582">FashionFail: Addressing
Failure Cases in Fashion Object Detection and Segmentation</a>. <br>Upload your image and choose the model for inference
from the dropdown menu—either `Facere` or `Facere+` <br> Check out the <a
href="https://rizavelioglu.github.io/fashionfail/">project page</a> for more information."""
article = r"""
Example images are sampled from the `FashionFail-test` set, which the models did not see during training.
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
a citation:
```
@inproceedings{velioglu2024fashionfail,
author = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
title = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
journal = {IJCNN},
eprint = {2404.08582},
year = {2024},
}
```
"""
# Example rows for the Gradio interface: [image path, model name, mask threshold, bbox threshold].
# Every example image is listed once per model, "facere" first and "facere+" second.
_EXAMPLE_IMAGES = (
    "adi_103_6.jpg",
    "adi_1201_2.jpg",
    "adi_2149_5.jpg",
    "adi_5476_3.jpg",
    "adi_5641_4.jpg",
)
examples = [
    [image_file, model, 0.5, 0.7]
    for image_file in _EXAMPLE_IMAGES
    for model in ("facere", "facere+")
]
# Gradio interface. The inputs are listed in the same order as the parameters of
# `inference`: image path, model name, mask threshold, bbox threshold.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="input"),
        gr.Dropdown(["facere", "facere+"], value="facere", label="Models"),
        gr.Slider(
            value=0.5,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="Mask threshold",
            info="a threshold for filtering out low-probability (pixel-wise) mask predictions",
        ),
        gr.Slider(
            value=0.7,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="BBox threshold",
            info="a threshold for filtering out low-scoring bbox predictions",
        ),
    ],
    outputs=gr.Gallery(label="output", preview=True, height=500),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
    examples_per_page=6,
)


if __name__ == "__main__":
    demo.launch()