File size: 4,054 Bytes
420fa3e
78d66d3
fbd12ae
420fa3e
4ce6c6b
c73b59b
7788a89
c73b59b
78d66d3
4ce6c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c73b59b
5017de6
c73b59b
 
420fa3e
5017de6
db79c47
 
78d66d3
c73b59b
 
a4a6a96
c73b59b
 
c30e671
4ce6c6b
420fa3e
c73b59b
 
 
 
 
 
5017de6
c73b59b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78d66d3
c73b59b
 
 
 
 
78d66d3
c73b59b
 
78d66d3
c73b59b
 
 
 
 
 
a4a6a96
c73b59b
78d66d3
c73b59b
 
 
78d66d3
a4a6a96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5017de6
a4a6a96
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from transformers import pipeline, SamModel, SamProcessor
import torch
import os
import numpy as np
import spaces
import gradio as gr
import shutil
from PIL import Image

def find_cuda():
    """Locate a CUDA toolkit installation directory.

    Lookup order:
      1. ``CUDA_HOME``, then ``CUDA_PATH``: the first variable that is set and
         names an existing directory wins. Each variable is checked on its
         own, so an invalid ``CUDA_HOME`` does not hide a valid ``CUDA_PATH``.
      2. The ``nvcc`` executable found on ``PATH``. Symlinks are resolved
         first, so a launcher such as ``/usr/bin/nvcc -> /opt/cuda/bin/nvcc``
         yields ``/opt/cuda`` rather than ``/usr``. The toolkit root is taken
         to be the parent of the ``bin`` directory holding ``nvcc``.

    Returns:
        The installation path as a string, or ``None`` if nothing was found.
    """
    for env_var in ("CUDA_HOME", "CUDA_PATH"):
        candidate = os.environ.get(env_var)
        if candidate and os.path.isdir(candidate):
            return candidate

    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        # Follow symlinks to the real binary, then strip the trailing "bin/nvcc".
        return os.path.dirname(os.path.dirname(os.path.realpath(nvcc_path)))

    return None

cuda_path = find_cuda()

# Startup diagnostic: report whether (and where) a CUDA toolkit was located.
print(
    f"CUDA installation found at: {cuda_path}"
    if cuda_path
    else "CUDA installation not found"
)
    
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# OWLv2 zero-shot detector: turns free-text candidate labels into boxes.
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(task="zero-shot-object-detection", model=checkpoint, device=device)

# RobustSAM (SAM variant): turns those boxes into segmentation masks.
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-huge")
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-huge").to(device)

def apply_mask(image, mask, color):
    """Overwrite the pixels selected by ``mask`` with ``color``, in place.

    Channels 0, 1 and 2 of ``image`` are filled with ``color[0]``,
    ``color[1]`` and ``color[2]`` respectively wherever ``mask`` is truthy.
    Any further channels (e.g. alpha) are left alone.

    Args:
        image: Array indexed as ``image[:, :, channel]``; modified in place.
        mask: Array broadcastable against one channel plane of ``image``.
        color: Indexable with at least three fill values, e.g. ``(r, g, b)``.

    Returns:
        The same ``image`` object that was passed in.
    """
    for channel in (0, 1, 2):
        fill = color[channel]
        plane = image[:, :, channel]
        image[:, :, channel] = np.where(mask, fill, plane)
    return image

@spaces.GPU
def query(image, texts, threshold, min_mask_score=0.5):
    """Detect objects matching text labels and paint their RobustSAM masks.

    OWLv2 (``detector``) proposes boxes for the comma-separated labels.
    RobustSAM (``sam_model``) then segments every sufficiently confident box,
    and each mask is painted as a solid colour onto a copy of the input.

    Args:
        image: Input image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without an image.
        texts: Comma-separated candidate labels, e.g. ``"bus, window"``.
            Whitespace around each label is stripped, and empty entries
            (``"cat,,dog"``, a trailing comma) are dropped.
        threshold: Score threshold forwarded to the OWLv2 pipeline.
        min_mask_score: Only detections scoring strictly above this value are
            segmented and painted. It defaults to 0.5, so a ``threshold``
            below 0.5 changes what the detector returns but not what is drawn.

    Returns:
        A new PIL image with the masks painted on. When there are no usable
        labels, or no detection passes ``min_mask_score``, the result is an
        unpainted copy of the input. Returns ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    labels = [label.strip() for label in (texts or "").split(",") if label.strip()]
    # np.array copies the pixel data, so painting never touches the caller's image.
    pixels = np.array(image)
    if not labels:
        # Nothing to search for: skip the detector instead of passing it no labels.
        return Image.fromarray(pixels)

    predictions = detector(
        image,
        candidate_labels=labels,
        threshold=threshold
    )

    colors = [
        (255, 0, 0),  # Red
        (0, 255, 0),  # Green
        (0, 0, 255),  # Blue
        (255, 255, 0),  # Yellow
        (255, 165, 0),  # Orange
        (255, 0, 255)  # Magenta
    ]

    # Keep each prediction's rank so colours cycle over *all* predictions,
    # including the low-score ones that are skipped.
    kept = [
        (rank, pred)
        for rank, pred in enumerate(predictions)
        if pred["score"] > min_mask_score
    ]
    if not kept:
        return Image.fromarray(pixels)

    boxes = [
        [round(pred["box"][key], 2) for key in ("xmin", "ymin", "xmax", "ymax")]
        for _, pred in kept
    ]

    # Segment every box in a single forward pass over the clean pixels.
    # SAM must not see masks already painted for earlier boxes, and encoding
    # the image once avoids re-running the image encoder per box.
    inputs = sam_processor(
        pixels,
        input_boxes=[boxes],  # one image, len(boxes) box prompts
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = sam_model(**inputs)

    # [0] selects our single image; the result has one entry per box, in the
    # same order as `boxes`.
    box_masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0]

    # SAM has finished reading `pixels`, so it is now safe to paint in place.
    # masks[0] takes the first of the candidate masks predicted for each box.
    for (rank, _), masks in zip(kept, box_masks):
        color = colors[rank % len(colors)]  # cycle through colors
        pixels = apply_mask(pixels, masks[0].numpy() > 0.5, color)

    return Image.fromarray(pixels)

# Markdown shown at the top of the page; both strings below are rendered
# with gr.Markdown inside the Blocks layout.
title = """
# RobustSAM
"""

# User-facing intro: what RobustSAM is, how OWLv2 makes it text-promptable,
# and how to use the demo.
description = """
**Welcome to RobustSAM by Snap Research.**

This Space uses **RobustSAM**, a robust version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities.

Thanks to its integration with **OWLv2**, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality.

Try the example or input an image with comma-separated candidate labels to see the enhanced segmentation results.

For better results, please check the [GitHub repository](https://github.com/robustsam/RobustSAM).
"""

# Page layout: the markdown header, then a gr.Interface that wires `query`
# to an image input, a candidate-labels textbox and a threshold slider.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    
    gr.Interface(
        query,
        # Inputs map positionally onto query(image, texts, threshold).
        # NOTE(review): the slider feeds the detector's threshold, but query
        # only paints detections scoring above 0.5, so slider values below
        # 0.5 never add masks. Confirm this is the intended UX.
        inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label="Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
        outputs=gr.Image(type="pil", label="Segmented Image"),
        # Example rows: [image path, comma-separated labels, threshold].
        # The paths are relative, so they are resolved against the working
        # directory. Verify the image files ship with the Space.
        examples=[
        ["./blur.jpg", "insect", 0.1],
        ["./lowlight.jpg", "bus, window", 0.1]
        ],
        # Per Gradio's `cache_examples` docs, the example outputs are
        # precomputed by calling `query`, so the models must be usable at startup.
        cache_examples=True
    )

# Starts the Gradio server. There is no __main__ guard, so this also runs
# whenever the module is imported.
demo.launch()