import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os
# Listing detectron2 directly in requirements.txt makes pip try to build it
# before torch is installed, which fails even when torch is listed first.
# Hence, detectron2 is installed at runtime when running as a Gradio HF Space.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
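# (An alternative, assuming the Space supports it, is to list torch in a
# pre-requirements.txt so it is installed before requirements.txt is resolved.)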
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

def load_model_and_processor(model_ckpt: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor

def load_default_ckpt(segmentation_task: str):
    if segmentation_task == "semantic":
        default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_ckpt

def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        # detectron2's visualizer expects 'category_id' and 'isthing' keys,
        # while the HF post-processor returns 'label_id'.
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)
    visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img

def draw_semantic_segmentation(segmentation_map, image, palette):
    color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)  # (height, width, 3)
    # Labels from post_process_semantic_segmentation are 0-indexed, so compare directly
    for label, color in enumerate(palette):
        color_segmentation_map[segmentation_map == label, :] = color
    # Convert to BGR
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    # Blend the color map with the original image at 50% opacity
    img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
    img = img.astype(np.uint8)
    return img

def visualize_instance_seg_mask(mask, input_image):
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    # Assign a random RGB color to each instance id
    label2color = {int(label): (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for label in labels}
    for label, color in label2color.items():
        # The ids come straight from np.unique(mask), so compare directly
        color_segmentation_map[mask == label, :] = color
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
    img = img.astype(np.uint8)
    return img

def predict_masks(input_img_path: str, segmentation_task: str):
    # Load model and image processor
    default_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_ckpt)

    # Pass input image through image processor; force RGB so the later
    # 3-channel blending works for grayscale/RGBA inputs too
    image = Image.open(input_img_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt")
    # Move inputs to the same device as the model (it may be on CUDA)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Pass inputs to model for prediction
    with torch.no_grad():
        outputs = model(**inputs)

    # Pass outputs to processor for postprocessing
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
        output_heading = "Semantic Segmentation Output"
    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_instance_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
        output_heading = "Instance Segmentation Output"
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result['segments_info']
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
        output_heading = "Panoptic Segmentation Output"
    return output_result, output_heading
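
# A minimal sketch of how predict_masks could be wired into a Gradio demo.
# This layout is illustrative, not part of the original Space; component
# names follow the gradio>=3 API.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=predict_masks,
        inputs=[
            gr.Image(type="filepath", label="Input Image"),
            gr.Radio(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task"),
        ],
        outputs=[
            gr.Image(label="Segmentation Map"),
            gr.Textbox(label="Output Heading"),
        ],
        title="Mask2Former Universal Image Segmentation",
    )
    demo.launch()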