macadeliccc committed
Commit c211f28
1 Parent(s): 43ee882

Create app.py

Files changed (1)
  1. app.py +160 -0
app.py ADDED
@@ -0,0 +1,160 @@
import gradio as gr
from pydantic import BaseModel
from typing import List, Optional
from PIL import Image, ImageDraw, ImageFont
import random
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import logging
from logging.handlers import RotatingFileHandler
import base64
import io
import os
import numpy as np

# Pydantic schemas for the detection request/response payloads
class DetectionRequest(BaseModel):
    image_data: str
    texts: List[List[str]]

class DetectionResult(BaseModel):
    detections: List[str]
    image_with_boxes: str

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()

# Create logs directory if it doesn't exist
if not os.path.exists('logs'):
    os.makedirs('logs')

# Route logs to a rotating file (file name and size limits are arbitrary defaults)
handler = RotatingFileHandler('logs/app.log', maxBytes=1_000_000, backupCount=3)
logging.basicConfig(level=logging.INFO, handlers=[handler])

def draw_bounding_boxes(image: Image.Image, boxes, scores, labels, text_labels):
    draw = ImageDraw.Draw(image)

    # Define the color bank
    color_bank = ["#0AC2FF", "#47FF0A", "#FF0AC2", "#ADD8E6", "#FF0A47"]

    # Use the default font
    font = ImageFont.load_default()

    for box, score, label in zip(boxes, scores, labels):
        # Choose a random color
        color = random.choice(color_bank)

        # Convert the box to a Python list if it isn't one already
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        elif not isinstance(box, (list, tuple)):
            raise TypeError("Box must be a list or tuple of coordinates.")

        # Draw the rectangle
        draw.rectangle(box, outline=color, width=2)

        # Label text: query string plus confidence
        display_text = f"{text_labels[label]}: {score:.2f}"

        # Place the text just above the box, clamped to the image
        text_position = (box[0], max(box[1] - 10, 0))
        draw.text(text_position, display_text, fill=color, font=font)

    return image

def detect_objects_logic(image_data, texts):
    try:
        # Decode the base64 image
        image_data_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_data_bytes)).convert("RGB")

        inputs = processor(text=texts, images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # post_process_object_detection already scales boxes to the
        # (height, width) given in target_sizes, i.e. to pixel coordinates
        target_sizes = torch.Tensor([image.size[::-1]])
        results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)

        detection_strings = []
        image_with_boxes = image.copy()  # Copy the image only once

        logging.info(f"Processing texts: {texts}")
        for i, text_group in enumerate(texts):
            if i >= len(results):
                logging.error(f"Text group index {i} exceeds results length.")
                continue
            results_per_group = results[i]
            boxes = results_per_group["boxes"]
            scores = results_per_group["scores"]
            labels = results_per_group["labels"]

            image_with_boxes = draw_bounding_boxes(image_with_boxes, boxes, scores, labels, text_group)

            for box, score, label in zip(boxes, scores, labels):
                # Boxes are already in pixel coordinates, so only rounding is needed
                rounded_box = [round(coord, 2) for coord in box.tolist()]
                detection_string = f"Detected {text_group[label]} with confidence {round(score.item(), 3)} at location {rounded_box}"
                detection_strings.append(detection_string)

        logging.info("Bounding boxes and labels have been drawn on the image.")

        return image_with_boxes, detection_strings

    except IndexError as e:
        logging.error(f"Index error: {e}. Check if the number of text groups matches the model's output.")
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}", exc_info=True)
        raise

def gradio_detect_and_draw(image, text_labels):
    # Check if the image is None
    if image is None:
        raise ValueError("No image was provided.")

    # Convert the input image to a PIL Image if it's a numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Convert the PIL Image to base64 for the detection logic
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")
    image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Split the comma-separated labels into a single query group
    text_labels = [[label.strip() for label in text_labels.split(',')]] if text_labels else []

    # Call the detection logic
    processed_image, detections = detect_objects_logic(image_data, text_labels)

    # Convert the output image to a PIL Image if it's a numpy array
    if isinstance(processed_image, np.ndarray):
        processed_image = Image.fromarray(processed_image.astype('uint8'), 'RGB')

    return processed_image, "\n".join(detections)


with gr.Blocks() as demo:
    gr.Markdown("## OWLv2 Object Detection Demo")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload or draw an image")
            text_input = gr.Textbox(label="Enter comma-separated labels for detection")
            submit_button = gr.Button("Detect")
        with gr.Column():
            image_output = gr.Image(label="Processed Image")
            text_output = gr.Text(label="Detections")

    submit_button.click(
        gradio_detect_and_draw,
        inputs=[image_input, text_input],
        outputs=[image_output, text_output]
    )

    # Add examples
    examples = [
        ["https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", "snowman"],
        ["https://history.iowa.gov/sites/default/files/primary-sources/images/history-education-pss-transportation-centralpark-source.jpg", "taxi,traffic light"],
        ["https://i.pinimg.com/1200x/51/e1/a1/51e1a12517e95725590d3a4b1a7575d7.jpg", "umbrella"]
    ]
    gr.Examples(examples, inputs=[image_input, text_input])


demo.launch()
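
For quick testing outside the UI, detect_objects_logic can also be called directly with a base64-encoded image. A minimal sketch, assuming a local file (the path and query strings below are placeholders, not part of the app):

import base64

with open("example.jpg", "rb") as f:  # placeholder image path
    encoded = base64.b64encode(f.read()).decode("utf-8")

# One inner list of label queries per image, matching the processor's expected format
annotated, detections = detect_objects_logic(encoded, [["snowman", "tree"]])
annotated.save("annotated.jpg")
print("\n".join(detections))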