from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np

# load the CLIPSeg processor and model once at startup
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
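
# Inference runs on the CPU by default. A minimal sketch for GPU use (an
# assumption: it requires CUDA, and `inputs` in process_image would also
# need to be moved with .to(model.device)):
#
#   if torch.cuda.is_available():
#       model = model.to("cuda")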


def process_image(image, prompt):
    """Run CLIPSeg on one image/prompt pair and return a float mask in [0, 1]."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # predict
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits

    # turn the logits into a grayscale mask at the original image size
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]

    # normalize the mask to [0, 1]; the epsilon guards against division by
    # zero when the mask is constant
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / max(mask_max - mask_min, 1e-8)
    return mask
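
# A minimal usage sketch for process_image; "example.jpg" and the prompt
# are placeholders, not files or strings the app ships with:
#
#   img = Image.open("example.jpg").convert("RGB")
#   heatmap = process_image(img, "a cat")  # float array in [0, 1], shape (H, W)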


def get_masks(prompts, img, threshold):
    """Return one binary mask per comma-separated prompt."""
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        masks.append(mask > threshold)
    return masks


def extract_image(img, pos_prompts, neg_prompts, threshold, alpha_value):
    """Extract the regions matched by the positive prompts minus the negative ones."""
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # combine the per-prompt masks with a logical OR; an empty prompt list
    # yields an all-False mask
    blank = np.zeros(img.size[::-1], dtype=bool)
    pos_mask = np.any(np.stack(positive_masks), axis=0) if positive_masks else blank
    neg_mask = np.any(np.stack(negative_masks), axis=0) if negative_masks else blank

    # keep pixels selected by a positive prompt and not by any negative prompt
    final_mask = pos_mask & ~neg_mask

    # apply the alpha value to the mask and its inverse
    alpha_mask = Image.fromarray((final_mask * 255 * alpha_value).astype(np.uint8), "L")
    inverse_alpha_mask = Image.fromarray(((~final_mask) * 255 * alpha_value).astype(np.uint8), "L")

    # composite the original image through the mask
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=alpha_mask)

    return output_image, alpha_mask, inverse_alpha_mask
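
# A minimal usage sketch for extract_image; the prompts, threshold, and
# alpha below are illustrative values:
#
#   result, mask, inverse = extract_image(img, "a cat", "background", 0.4, 1.0)
#   result.save("cat_cutout.png")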


title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. Upload an image, describe what to identify and what to ignore as comma-separated prompts, adjust the threshold and alpha sliders, and click 'Process'. Results appear after a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"


with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )

            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # opacity applied to the extracted region
            input_slider_A = gr.Slider(
                minimum=0, maximum=1, value=1.0, label="Alpha"
            )
            btn_process = gr.Button("Process")

        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
            inverse_mask = gr.Image(label="Inverse")

    btn_process.click(
        extract_image,
        inputs=[
            input_image,
            positive_prompts,
            negative_prompts,
            input_slider_T,
            input_slider_A,
        ],
        outputs=[output_image, output_mask, inverse_mask],
        api_name="mask"
    )


demo.launch()
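
# A sketch of calling the "mask" endpoint programmatically. Assumptions: the
# app is running locally on the default port, gradio_client >= 1.0 is
# installed, and "example.jpg" is a placeholder path:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       handle_file("example.jpg"),  # input image
#       "a cat",                     # positive prompts (comma separated)
#       "background",                # negative prompts (comma separated)
#       0.4,                         # threshold
#       1.0,                         # alpha
#       api_name="/mask",
#   )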