File size: 3,202 Bytes
9eb9e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import time
import numpy as np
import json
import torch
import gradio as gr

# Model setup: load the Grounding DINO checkpoint once at import time and
# place it on the GPU when one is available, otherwise on the CPU.
model_id = "IDEA-Research/grounding-dino-base"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)

# Text prompt: the candidate object classes, each terminated by ". "
# (evaluates to "Container. Bottle. Fruit. Vegetable. Packet.").
text = ". ".join(["Container", "Bottle", "Fruit", "Vegetable", "Packet"]) + "."

# Detection thresholds.
iou_threshold = 0.4    # IoU above which an overlapping box is dropped by NMS
box_threshold = 0.3    # passed as box_threshold to the post-processing step
score_threshold = 0.3  # passed as text_threshold to the post-processing step

# Function to detect objects in an image and return a JSON with count, class, box, and score
def detect_objects(image):
    """Detect objects in *image* and return the detections as a JSON string.

    The image is run through the module-level Grounding DINO ``model`` with
    the module-level prompt ``text``. Raw detections are filtered by the
    processor's post-processing using ``box_threshold`` / ``score_threshold``.
    Overlapping boxes are then removed with class-agnostic greedy NMS at
    ``iou_threshold``.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or ``None``
            when the user submits without uploading anything.

    Returns:
        A JSON string ``{"count": int, "class": [...], "box": [[x1, y1, x2, y2], ...],
        "score": [...]}``. The three lists are parallel and ordered by
        descending score.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the Gradio UI instead of failing
        # somewhere inside preprocessing.
        raise gr.Error("Please upload an image.")

    # Prepare inputs for the model
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results. PIL's ``size`` is (width, height), while
    # ``target_sizes`` expects (height, width), hence the reversal.
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=score_threshold,
        target_sizes=[image.size[::-1]],
    )
    detections = results[0]

    boxes, labels, scores = _class_agnostic_nms(
        detections["boxes"].tolist(),
        list(detections["labels"]),
        detections["scores"].tolist(),
        iou_threshold,
    )

    # Prepare the output in the requested format
    output = {
        "count": len(boxes),
        "class": labels,
        "box": boxes,
        "score": scores,
    }
    return json.dumps(output)


def _class_agnostic_nms(boxes, labels, scores, threshold):
    """Greedy, class-agnostic non-maximum suppression.

    Visits detections from highest to lowest score. A detection is kept only
    if its IoU with every already-kept box is at most *threshold*. When two
    boxes overlap, the higher-scoring one survives regardless of its label.

    Args:
        boxes: List of ``[x1, y1, x2, y2]`` boxes.
        labels: Labels parallel to *boxes*.
        scores: Float scores parallel to *boxes*.
        threshold: IoU above which a lower-scoring box is discarded.

    Returns:
        ``(kept_boxes, kept_labels, kept_scores)`` as lists, ordered by
        descending score. Ties keep their original relative order.
    """
    # Visiting in score order is what makes this NMS correct. Without the
    # sort, whichever box happens to come first wins, even if a later
    # overlapping box scores higher.
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

    kept_boxes, kept_labels, kept_scores = [], [], []
    for idx in order:
        box = boxes[idx]
        if any(_box_iou(box, kept) > threshold for kept in kept_boxes):
            continue
        kept_boxes.append(box)
        kept_labels.append(labels[idx])
        kept_scores.append(scores[idx])
    return kept_boxes, kept_labels, kept_scores


def _box_iou(box1, box2):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Both boxes are ``[x1, y1, x2, y2]`` sequences with ``x1 <= x2`` and
    ``y1 <= y2``.

    Returns:
        The IoU as a float. It is 0.0 when the boxes do not overlap, and also
        when the union area is zero (degenerate boxes), instead of raising
        ``ZeroDivisionError``.
    """
    x1, y1, x2, y2 = box1
    x1_2, y1_2, x2_2, y2_2 = box2

    # Width/height of the intersection rectangle (negative => no overlap).
    inter_w = min(x2, x2_2) - max(x1, x1_2)
    inter_h = min(y2, y2_2) - max(y1, y1_2)
    if inter_w < 0 or inter_h < 0:
        return 0.0  # No intersection

    intersection_area = inter_w * inter_h
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
    union_area = area1 + area2 - intersection_area
    if union_area <= 0:
        return 0.0  # Degenerate (zero-area) boxes
    return intersection_area / union_area

# Define Gradio input and output components
image_input = gr.Image(type="pil")

# Create the Gradio interface. detect_objects returns a JSON string, so the
# output component is plain text (no annotated image is produced).
demo = gr.Interface(
    fn=detect_objects,
    inputs=image_input,
    outputs='text',
    title="Freshness prediction",
    description=(
        "Upload an image, and the model will detect objects and return a JSON "
        "result with the number of objects and each object's class, "
        "bounding box and confidence score."
    ),
)

# share=True additionally requests a public Gradio share link.
demo.launch(share=True)