import gradio as gr
import numpy as np
import cv2
import torch
from torch import nn
from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
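# Load the pretrained CLIPSeg (rd64-refined) processor and segmentation model.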
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def create_rgb_mask(mask):
    """Colorize a binary {0, 255} mask with a random RGB color."""
    color = tuple(np.random.choice(range(256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)
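# Note: create_rgb_mask is a standalone helper; the overlay coloring shown in the
# app is done by segmentation_mask_overlay.overlay_masks inside shot() below.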
def detect_using_clip(image, prompts, threshold=0.1):
    # Run every prompt against the same image in a single batch.
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    # Upsample the low-resolution logits back to the input image size.
    preds = nn.functional.interpolate(
        outputs.logits.unsqueeze(1),
        size=(image.shape[0], image.shape[1]),
        mode="bilinear",
    )
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))
    # Prepend a dummy "unlabeled" row filled with the threshold, so any pixel whose
    # best prompt score falls below it gets assigned to index 0 (background).
    flat_preds_with_threshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)
    flat_preds_with_threshold[1:preds.shape[0] + 1, :] = flat_preds
    # Take the top-scoring row index for each pixel.
    inds = torch.topk(flat_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    predicted_masks = []
    for i in range(1, len(prompts) + 1):
        mask = np.where(inds == i, 255, 0)
        predicted_masks.append(mask)
    return predicted_masks
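# Example (sketch, not executed by the app): running detection on a local file,
# assuming "images/room.jpg" exists as it does in the Gradio examples below.
#   img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
#   masks = detect_using_clip(img, prompts=["chair", "plant"])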
def visualize_images(image, predicted_images, brightness=15, contrast=1.8):
    # Helper kept for experimentation; it is not wired into the Gradio UI and the
    # mask-blending loop below is left disabled.
    alpha = 0.7
    image_resize = cv2.resize(image, (352, 352))
    resize_image_copy = image_resize.copy()
    # for mask_image in predicted_images:
    #     resize_image_copy = cv2.addWeighted(resize_image_copy, alpha, mask_image, 1 - alpha, 10)
    return cv2.convertScaleAbs(resize_image_copy, alpha=contrast, beta=brightness)
def shot(alpha, beta, image, labels_text):
    # Split the comma-separated label string into clean prompts.
    prompts = [prompt.strip() for prompt in labels_text.split(",")]
    mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(prompts)]
    # One distinct tab20 color per label (the trailing alpha channel is dropped).
    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
    predicted_masks = detect_using_clip(image, prompts=prompts)
    bool_masks = [predicted_mask.astype("bool") for predicted_mask in predicted_masks]
    category_image = overlay_masks(image, np.stack(bool_masks, -1), labels=mask_labels, colors=cmap, alpha=alpha, beta=beta)
    return category_image
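# Example (sketch): invoking the handler directly with one of the bundled examples.
#   overlay = shot(0.4, 0.7, cv2.cvtColor(cv2.imread("images/seats.jpg"), cv2.COLOR_BGR2RGB), "door,table,chairs")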
iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(0.1, 1, value=0.3, step=0.1, label="alpha", info="Choose between 0.1 and 1"),
        gr.Slider(0.1, 1, value=0.7, step=0.1, label="beta", info="Choose between 0.1 and 1"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and the labels to be detected, separated by commas (at least 2)",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        [0.4, 0.7, "images/room.jpg", "chair, plant, flower pot, white cabinet, paintings, decorative plates, books"],
        [0.4, 0.7, "images/seats.jpg", "door,table,chairs"],
        [0.3, 0.8, "images/vegetables.jpg", "carrot,white radish,brinjal,basket,potato"],
        [0.4, 0.7, "images/dashcam.jpeg", "car,sky,road,grassland,trees"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
if __name__ == "__main__":
    iface.launch()