File size: 2,356 Bytes
ff605cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
import torch
import cv2
import traceback
import numpy as np
from transformers import SamModel, SamProcessor
# The SAM weights and their matching pre/post-processor are loaded from the
# same Hugging Face Hub checkpoint so that preprocessing matches the model.
_SAM_CHECKPOINT = 'facebook/sam-vit-huge'
model = SamModel.from_pretrained(_SAM_CHECKPOINT)
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
def set_predictor(image):
    """
    Compute the SAM image embedding for an uploaded image.

    The embedding depends only on the image, not on any prompt, so computing
    it once here lets get_polygon answer many bounding-box prompts for the
    same image without re-running the vision encoder.

    Args:
        image: The value delivered by the gr.Image component (presumably an
            RGB numpy array, Gradio's default -- confirm if the component's
            ``type`` is changed), or None when nothing has been uploaded.

    Returns:
        A three-element list [image, image_embedding, 'Done'], matching the
        [image, embedding, output_status] outputs of the "Send Image" button.

    Raises:
        gr.Error: If no image was uploaded; Gradio shows the message to the
            user and leaves the stored state unchanged.
    """
    if image is None:
        raise gr.Error('Please upload an image first')
    # Follow wherever the model lives instead of hard-coding 'cpu', so moving
    # the model to another device does not require touching this function.
    device = model.device
    inputs = processor(image, return_tensors='pt').to(device)
    # Pure inference. get_image_embeddings may already disable gradients
    # internally; being explicit guarantees the cached embedding never
    # carries an autograd graph.
    with torch.no_grad():
        image_embedding = model.get_image_embeddings(inputs['pixel_values'])
    return [image, image_embedding, 'Done']
def _parse_bbox(text):
    """
    Parse an "x1,y1,x2,y2" string into a list of four floats.

    Raises gr.Error (shown to the user by Gradio) when the text is missing,
    non-numeric, or does not contain exactly four comma-separated values.
    """
    message = f'Invalid bounding box {text!r}: expected four numbers "x1,y1,x2,y2"'
    try:
        coords = [float(part) for part in text.split(',')]
    except (AttributeError, ValueError):
        # AttributeError covers a None value; ValueError covers empty or
        # non-numeric fields such as "" or "10,abc,30,40".
        raise gr.Error(message) from None
    if len(coords) != 4:
        raise gr.Error(message)
    return coords


def _mask_to_polygon(mask):
    """
    Return the outer boundary of the largest region of a 2-D mask as [[x, y], ...].

    Returns an empty list when the mask has no foreground pixels.
    """
    binary = mask.astype(np.uint8)
    # OpenCV 4 returns (contours, hierarchy) and OpenCV 3 returns
    # (image, contours, hierarchy); [-2] selects the contours in both.
    # RETR_EXTERNAL skips hole contours, so only outer outlines are candidates.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return []
    # A mask can consist of several disconnected blobs; keep the largest one,
    # which is the most likely outline of the prompted object, rather than
    # whichever contour OpenCV happened to list first.
    largest = max(contours, key=cv2.contourArea)
    # OpenCV contours have shape (N, 1, 2); flatten to N [x, y] pairs of ints.
    return largest.reshape(-1, 2).tolist()


def get_polygon(points, image, image_embedding):
    """
    Segment the object inside a bounding box and return its outline polygon.

    Args:
        points: Bounding-box text "x1,y1,x2,y2", presumably in pixel
            coordinates of the uploaded image.
        image: The image stored by set_predictor.
        image_embedding: The SAM embedding of ``image`` computed by
            set_predictor.

    Returns:
        A list of [x, y] integer vertices outlining the largest region of the
        best-scoring SAM mask, or an empty list if that mask is empty.

    Raises:
        gr.Error: If no image/embedding has been computed yet, or the bounding
            box text is malformed. Gradio displays the message to the user.
    """
    if image is None or image_embedding is None:
        raise gr.Error(
            'No image embedding yet: upload an image and click "Send Image" first'
        )
    box = _parse_bbox(points)
    device = model.device
    inputs = processor(image, input_boxes=[box], return_tensors="pt").to(device)
    # The model takes either pixel_values or precomputed image_embeddings;
    # drop the pixels and reuse the cached embedding instead of re-encoding.
    inputs.pop("pixel_values", None)
    inputs.update({"image_embeddings": image_embedding})
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0] holds the masks for our single image, indexed
    # [box, candidate, H, W]. SAM proposes several candidates per prompt;
    # keep the one with the highest predicted IoU instead of candidate 0.
    best = int(outputs.iou_scores[0, 0].argmax())
    mask = masks[0][0, best].numpy()
    return _mask_to_polygon(mask)
with gr.Blocks() as app:
    # Per-session state shared between the two tabs: the uploaded image and
    # its SAM embedding, written by set_predictor and read by get_polygon.
    image = gr.State()
    embedding = gr.State()
    with gr.Tab('Get embedding'):
        input_image = gr.Image(label='Image')
        output_status = gr.Textbox(label='Status')
        predictor_button = gr.Button('Send Image')
    with gr.Tab('Get points'):
        # get_polygon parses this text as four comma-separated numbers.
        bbox = gr.Textbox(label="bbox", placeholder="x1,y1,x2,y2")
        polygon = [gr.Textbox(label='Polygon')]
        points_button = gr.Button('Send bounding box')
    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )
    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        polygon,
    )

# Queue configuration stays attached to the app object even when another tool
# imports this module and launches it.
app.queue()

# Only start the (blocking) server when run as a script, not on import.
if __name__ == '__main__':
    app.launch(debug=True)