# Repository file-viewer metadata (not part of the program):
#   author: mrtlive
#   commit: 79b5495 - "added some segmentation functions"
#   file size: 1.85 kB
import io
import os

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
# Select the non-interactive Agg backend so matplotlib works inside Gradio.
matplotlib.pyplot.switch_backend('Agg')

# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE: this value must not be overwritten with a hard-coded "cuda" later on,
# or moving the model to the device fails on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the model: build it from the checkpoint and move it to the device.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Both helpers share the same loaded model instance.
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamPredictor(sam)
def show_anns(anns):
    """Overlay every mask in *anns* on the current matplotlib axes.

    Annotations are drawn in order of decreasing ``'area'``, so larger masks
    are painted first and smaller ones end up on top. Each mask gets its own
    random colour.

    Args:
        anns: Sequence of annotation dicts. Each dict is read for ``'area'``
            (the sort key) and ``'segmentation'`` (an array whose first two
            dimensions give the overlay size; presumably a boolean/0-1 mask,
            verify against the mask generator).

    Returns:
        None. The overlays are drawn on ``plt.gca()``; nothing is drawn when
        *anns* is empty.
    """
    if not anns:
        return

    ax = plt.gca()
    # Keep the axis limits set by the underlying image; overlays must not rescale it.
    ax.set_autoscale_on(False)

    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One random RGB colour, broadcast over the whole mask area.
        rgb = np.ones((mask.shape[0], mask.shape[1], 3)) * np.random.random(3)
        # Fourth channel is the alpha: 0.35 wherever the mask is set, scaled
        # by the mask value otherwise.
        ax.imshow(np.dstack((rgb, mask * 0.35)))
def segment_image(image):
    """Segment *image* and return it with the generated masks overlaid.

    Args:
        image: Input picture as an array whose shape unpacks as
            ``(height, width, channels)``. ``None`` is accepted and returned
            unchanged (presumably what the UI passes when no image has been
            provided; verify against the Gradio version in use).

    Returns:
        A ``PIL.Image.Image`` in RGB mode showing the input with the mask
        overlays, or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    masks = mask_generator.generate(image)

    # Size the figure so that, at `ppi` dots per inch, it has the same pixel
    # dimensions as the input image.
    ppi = 100
    height, width, _ = image.shape
    fig = plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
    try:
        plt.imshow(image)
        show_anns(masks)
        plt.axis('off')
        # Render into memory instead of a fixed 'output.png' on disk, so that
        # overlapping requests cannot overwrite each other's result.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Every call creates a new figure; close it so figures do not pile up.
        plt.close(fig)

    buffer.seek(0)
    # Decode with PIL directly. The previous cv2.imread round trip produced
    # BGR data that was then interpreted as RGB, swapping red and blue.
    return Image.open(buffer).convert('RGB')
# Build the UI: a title, the input and output images side by side, and a
# button that runs the segmentation.
with gr.Blocks() as demo:
    gr.Markdown("## Segment-anything Demo")
    with gr.Row():
        image_input = gr.Image()
        image_output = gr.Image()
    segment_image_button = gr.Button("Segment Image")
    # Clicking the button feeds the input image to segment_image and shows
    # the returned picture in the output component.
    segment_image_button.click(segment_image, image_input, image_output)

demo.launch()