File size: 5,134 Bytes
9949d8d
 
 
20ddd4d
9949d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a9758c
9949d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
import matplotlib.pyplot as plt
import gradio as gr

import os
import io
import numpy as np
import torch
import cv2
from PIL import Image

def fig2img(fig):
    """Save ``fig`` into an in-memory buffer and return it opened as a PIL image.

    Args:
        fig: Object exposing ``savefig`` (the caller passes a Matplotlib figure).

    Returns:
        The image obtained by opening the saved buffer with ``Image.open``.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the image reader starts at the beginning of the saved data.
    buffer.seek(0)
    return Image.open(buffer)


def plot(
    annotations,
    prompt_process,
    bbox=None,
    points=None,
    point_label=None,
    mask_random_color=True,
    better_quality=True,
    retina=False,
    with_contours=True,
):
    """
    Draw the first annotation's masks, prompts, and contours and return the result as an image.

    Only the first element of ``annotations`` is rendered; the function returns as soon as
    that figure has been converted. Nothing is written to disk.

    Args:
        annotations (list): Annotation results to draw. Each item must expose ``orig_img``,
            ``orig_shape`` and ``masks``.
        prompt_process: Object providing ``fast_show_mask``, used to overlay the masks on the axes.
        bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
        points (list, optional): Points to be plotted. Defaults to None.
        point_label (list, optional): Labels for the points. Defaults to None.
        mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
        better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
            Defaults to True.
        retina (bool, optional): Whether to use retina mask. Defaults to False.
        with_contours (bool, optional): Whether to plot contours. Defaults to True.

    Returns:
        The image produced by ``fig2img`` for the first annotation, or None when
        ``annotations`` is empty.
    """
    ann = next(iter(annotations), None)
    if ann is None:
        return None

    image = ann.orig_img[..., ::-1]  # BGR to RGB
    original_h, original_w = ann.orig_shape
    # For macOS only
    # plt.switch_backend('TkAgg')
    fig = plt.figure(figsize=(original_w / 100, original_h / 100))
    try:
        # Add subplot with no margin.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            if better_quality:
                # Work on a NumPy copy so OpenCV can process the masks in place.
                if isinstance(masks, torch.Tensor):
                    masks = np.array(masks.cpu())
                close_kernel = np.ones((3, 3), np.uint8)
                open_kernel = np.ones((8, 8), np.uint8)
                for i, mask in enumerate(masks):
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                    masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

            prompt_process.fast_show_mask(
                masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    # Masks are still tensors when better_quality is False; OpenCV needs arrays.
                    if isinstance(mask, torch.Tensor):
                        mask = mask.cpu().numpy()
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                # Broadcast the single-channel contour map against the RGBA color.
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        plt.axis("off")
        return fig2img(fig)
    finally:
        # Always release the figure, even if drawing failed, so repeated calls
        # do not accumulate open figures in pyplot.
        plt.close(fig)

# Create the FastSAM model once at module level; generateOutput below calls
# this shared instance for every request instead of constructing a new one.
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt

def generateOutput(source):
    """Run FastSAM on ``source`` with the "everything" prompt and return the rendered image.

    Args:
        source: Input forwarded unchanged to the module-level ``model`` and to
            ``FastSAMPrompt``.

    Returns:
        Whatever ``plot`` returns for the everything-prompt results.
    """
    # Inference with the module-level model.
    detections = model(source, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    # Wrap the raw results so the everything prompt can be applied and drawn.
    prompter = FastSAMPrompt(source, detections, device="cpu")
    return plot(annotations=prompter.everything_prompt(), prompt_process=prompter)

# Heading and description passed to the Gradio interface.
title = "FastSAM Inference Trials"
description = "Shows the FastSAM related Inference Trials"

# One single-element row per example, matching the single image input.
examples = [
    ["941398-beautiful-farm-animals-wallpaper-2000x1402-for-meizu.jpg"],
    ["Puppies.jpg"],
    ["photo2.JPG"],
    ["MultipleItems.jpg"],
]

# Single image in, single image out, both handled by generateOutput.
demo = gr.Interface(
    fn=generateOutput,
    inputs=[gr.Image(width=256, height=256, label="Input Image")],
    outputs=[gr.Image(width=256, height=256, label="Output")],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
demo.launch()