File size: 1,813 Bytes
4fba7a2 f254011 4fba7a2 480594f f254011 480594f f254011 a49b93f 480594f a49b93f 480594f a49b93f f254011 480594f a49b93f f254011 a49b93f 480594f a49b93f f254011 a49b93f f254011 a49b93f 480594f f254011 480594f f254011 a49b93f f254011 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import gradio as gr
import numpy as np
import torch
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
# Hugging Face checkpoint used for both the feature extractor and the model,
# kept in one place so the two cannot drift apart. The larger
# "facebook/maskformer-swin-large-coco" checkpoint can be substituted here.
MODEL_CHECKPOINT = "facebook/maskformer-swin-tiny-coco"
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_CHECKPOINT)
def visualize_instance_seg_mask(mask, id2label, rng=None):
    """Colour a 2-D label mask and print the fraction of pixels per label.

    Each distinct label id in ``mask`` is assigned a random RGB colour and
    every pixel is painted with the colour of its label.

    Args:
        mask: 2-D array of label ids (height x width).
        id2label: mapping from label id to label name; must contain every
            id present in ``mask``.
        rng: optional ``numpy.random.Generator`` used to draw the colours,
            for reproducible output. When omitted, NumPy's global random
            state is used (as before).

    Returns:
        Float array of shape ``(height, width, 3)`` with values in [0, 1].

    Raises:
        KeyError: if ``mask`` contains an id missing from ``id2label``.
    """
    mask = np.asarray(mask)
    height, width = mask.shape[0], mask.shape[1]
    # Work on the flattened mask so ``inverse`` is 1-D on every NumPy
    # version (NumPy 2.0 changed its shape for multi-dimensional inputs).
    label_ids, inverse, counts = np.unique(
        mask.ravel(), return_inverse=True, return_counts=True
    )
    # One colour per label; every channel spans the full 0-255 range
    # (previously the red channel was drawn from randint(0, 2)).
    size = (len(label_ids), 3)
    if rng is None:
        palette = np.random.randint(0, 256, size=size)
    else:
        palette = rng.integers(0, 256, size=size)
    # Fancy indexing paints all pixels at once instead of a per-pixel loop.
    image = palette[inverse].reshape(height, width, 3) / 255
    total_pixels = height * width
    label2count = {
        id2label[label_id]: count / total_pixels
        for label_id, count in zip(label_ids.tolist(), counts.tolist())
    }
    print(label2count)
    return image
def query_image(img):
    """Segment an image with MaskFormer and return a colourised mask.

    Args:
        img: image array from ``gr.Image()`` (height x width x channels),
            or ``None`` when the user submits without an image.

    Returns:
        The colourised mask from ``visualize_instance_seg_mask``, resized to
        the input's height and width, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Gradio passes None for an empty input; there is nothing to segment.
        return None
    img_size = (img.shape[0], img.shape[1])
    inputs = feature_extractor(images=img, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    seg_map = feature_extractor.post_process_semantic_segmentation(
        outputs=outputs, target_sizes=[img_size]
    )[0]
    return visualize_instance_seg_mask(seg_map.cpu().numpy(), model.config.id2label)
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    # Must name the checkpoint actually loaded above (swin-tiny, not swin-large).
    title="maskformer-swin-tiny-coco results",
    allow_flagging="never",
    analytics_enabled=None,
)

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    demo.launch(show_api=False)
|