File size: 2,663 Bytes
cc7fbfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c97f39d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
import requests
from transformers import SamModel, SamProcessor
import cv2

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the SAM ViT-Base checkpoint once, at import time, and move the weights
# to the chosen device. The processor is loaded from the same Hub identifier so
# its preprocessing matches the model.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def mask_2_dots(mask):
    """Convert a painted-dots mask into SAM prompt points.

    Each separate painted blob in ``mask`` becomes one ``[x, y]`` point placed
    at the blob's centroid, in the mask's pixel coordinates.

    Args:
        mask: Numpy image of the user's strokes. It may be a 3-channel RGB
            array or a single-channel (2-D) grayscale array. After grayscale
            conversion, pixels with a value above 127 count as painted.

    Returns:
        ``[points]``, a one-element list wrapping ``[[x, y], ...]``. The outer
        list is the per-image nesting passed as ``input_points`` to the SAM
        processor. ``points`` is empty when nothing was painted.
    """
    # A 2-D mask is already grayscale; COLOR_RGB2GRAY would reject it.
    gray = mask if mask.ndim == 2 else cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # Morphological closing fills small gaps, so a slightly broken stroke is
    # still detected as a single blob.
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)

    # RETR_EXTERNAL keeps only outer outlines. RETR_TREE also returned the
    # inner outline of any hole (e.g. a ring-shaped stroke), adding spurious
    # points. [-2] selects the contour list on OpenCV 3 (three return values)
    # and OpenCV 4 (two return values) alike.
    contours = cv2.findContours(
        closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]

    points = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments["m00"] != 0:
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
        else:
            # Zero-area contour (e.g. a single pixel or a one-pixel-wide line).
            # The moment-based centroid would divide by zero, so average the
            # contour's vertices instead.
            cx, cy = (int(v) for v in contour.reshape(-1, 2).mean(axis=0))
        points.append([cx, cy])
    return [points]

def main_func(inputs):
    """Segment the object marked by the user's dots and build gallery images.

    Args:
        inputs: Value produced by the sketch-enabled image component: a dict
            with ``'image'`` (numpy image to segment) and ``'mask'`` (numpy
            image of the painted dots).

    Returns:
        A list of PIL images:
            - first, the input image with a red dot drawn at each prompt point;
            - then one black/white image for each mask returned by SAM's
              post-processing.

    Raises:
        gr.Error: If no dots were painted, so there is no prompt to give SAM.
    """
    points = mask_2_dots(inputs["mask"])
    if not points[0]:
        # An empty prompt list gives SAM nothing to segment. Show a readable
        # message in the UI instead of failing deep inside the processor.
        raise gr.Error("Please draw at least one dot on the object to segment.")

    image = Image.fromarray(inputs["image"])

    # Use a separate name so the `inputs` argument is not shadowed by the
    # processor's tensor batch.
    model_inputs = processor(image, input_points=points, return_tensors="pt").to(device)

    # Inference only: disable autograd so no gradient graph is kept in memory.
    with torch.no_grad():
        outputs = model(**model_inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        model_inputs["original_sizes"].cpu(),
        model_inputs["reshaped_input_sizes"].cpu(),
    )

    # Draw the prompt points only after the processor has consumed the image,
    # so the red dots never reach the model input.
    draw = ImageDraw.Draw(image)
    for x, y in points[0]:
        draw.ellipse((x - 10, y - 10, x + 10, y + 10), fill="red")

    # masks[0] holds the results for the single input image. squeeze(0) drops
    # its leading axis (presumably the size-1 point batch; confirm), leaving
    # one 2-D mask per entry along the first axis.
    mask_stack = masks[0].squeeze(0).numpy()

    gallery = [image]
    for single_mask in mask_stack:
        gallery.append(Image.fromarray((single_mask * 255).astype(np.uint8)))
    return gallery


# Build the Gradio UI: a sketchable image input, a button, and a gallery that
# shows the dotted image followed by the predicted masks.
with gr.Blocks() as demo:
    gr.Markdown("# Demo to run Segment Anything base model")
    gr.Markdown("""This app uses the [Segment Anything](https://huggingface.co/facebook/sam-vit-base) model from Meta to get a mask from points in an image.
    Currently it only works with dots for a single object, but I'm planning to add features to make it work for multiple objects.
    The output shows the image with the dots, then the 3 predicted masks.
    """)
    # The tab label previously read "Flip Image", which does not describe this
    # tab; it now names the segmentation task it performs.
    with gr.Tab("Segment Image"):
        with gr.Row():
            # tool='sketch' lets the user paint dots. main_func reads the
            # painted strokes from the 'mask' key and the picture from 'image'.
            image_input = gr.Image(tool='sketch')
            image_output = gr.Gallery()

        image_button = gr.Button("Segment Image")

    image_button.click(main_func, inputs=image_input, outputs=image_output)

demo.launch()