import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
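
# Note: with padding="max_length", the CLIPSeg processor resizes/pads inputs to
# 352x352 and the model emits one 352x352 logit map per text prompt, so any
# detected boxes must be rescaled back to the original image resolution.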
classes = list()

def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # regionprops returns (min_row, min_col, max_row, max_col) = (y1, x1, y2, x2)
    # on the model-resolution mask; map it back to the original image size.
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
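
# Worked example (hypothetical numbers, using the defaults above): a regionprops
# bbox of (88, 120, 176, 240) on a 352x352 mask maps onto a 1024x1024 image as
#   rescale_bbox((88, 120, 176, 240)) -> [256, 349, 512, 698]  # [y1, x1, y2, x2]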

def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    # CLIPSeg takes one (image, prompt) pair per prompt, so repeat the image.
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only; disable gradient computation
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    for i, prompt in enumerate(prompts):
        # sigmoid maps logits to per-pixel probabilities; threshold into a binary mask
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0)
        # extract connected regions from the mask and keep their bounding boxes
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]

    return model_detections
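
# Example sketch (hypothetical usage): given any RGB uint8 array, e.g.
#   img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
#   detect_using_clip(img, prompts=["bed", "plant"])
# the result is a dict like {"bed": [[y1, x1, y2, x2], ...], "plant": [...]}.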

def display_images(image, detections, prompt='traffic light'):
    image_copy = image.copy()
    if prompt not in detections:
        print("prompt not found in detections ..")
        return image_copy
    # draw a box around every region detected for this prompt
    for bbox in detections[prompt]:
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
    return image_copy


def shot(image, labels_text, selected_category):
    print("Labels Text ", labels_text)
    prompts = [p.strip() for p in labels_text.split(',')]

    print("prompts :", prompts)
    print("Image shape ", image.shape)

    detections = detect_using_clip(image, prompts=prompts)
    print("detections :", detections)
    print("Category ", selected_category)
    # the interface output is an image, so draw the boxes for the selected
    # category (falling back to the first prompt) instead of returning 0
    prompt = selected_category if selected_category else prompts[0]
    return display_images(image, detections, prompt=prompt)

iface = gr.Interface(
    fn=shot,
    # note: `classes` is empty when the interface is built, so the dropdown
    # starts with no options unless it is prepopulated above
    inputs=["image", "text", gr.Dropdown(classes, label="Category Label", info="Select Categories")],
    outputs="image",
    description="Add a picture and a list of labels separated by commas",
    title="Zero-Shot Object Detection with Text Prompts",
    # each example needs one value per input, including the dropdown
    examples=[["images/room.jpg", "bed,table,plant", "bed"]],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
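
# Note: launch() serves the demo locally (default http://127.0.0.1:7860);
# iface.launch(share=True) would additionally create a temporary public link.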