awpbash committed on
Commit
c7935e1
·
1 Parent(s): 01cca9a
Files changed (6)
  1. .gitattributes +1 -0
  2. app.py +269 -0
  3. images/image1.png +3 -0
  4. images/image2.png +3 -0
  5. images/image3.png +3 -0
  6. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,269 @@
+ import gradio as gr
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+ from geoclip import GeoCLIP
+ import tempfile
+ import os
+
+ # Set device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Global model variables, populated lazily on first use
+ processor, gdino_model, ocr_model, geo_model = None, None, None, None
+
+ def load_image(image_pil):
+     """
+     Converts a PIL image to a BGR NumPy array.
+     """
+     # Force 3-channel RGB first so RGBA or palette uploads do not break cvtColor
+     img_bgr = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
+     return img_bgr
+
+ def load_gdino():
+     """
+     Loads and returns the Grounding DINO model and processor.
+     """
+     global processor, gdino_model
+     if gdino_model is None:
+         print("Loading Grounding DINO model...")
+         model_id = "IDEA-Research/grounding-dino-base"
+         processor = AutoProcessor.from_pretrained(model_id)
+         gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+         print("Grounding DINO model loaded.")
+     return processor, gdino_model
+
+ def load_geoclip():
+     """
+     Loads and returns the GeoCLIP model.
+     """
+     global geo_model
+     if geo_model is None:
+         print("Loading GeoCLIP model...")
+         geo_model = GeoCLIP()
+         print("GeoCLIP model loaded.")
+     return geo_model
+
+ def detect_gdino(img_pil, processor, model, box_threshold, text_threshold, queries):
+     """
+     Performs zero-shot object detection using Grounding DINO.
+     """
+     if not queries:
+         return np.empty((0, 4), dtype=int)
+
+     # Grounding DINO expects lowercase phrases separated by ". " and a trailing "."
+     text = ". ".join([q.lower() for q in queries]) + "."
+     inputs = processor(images=img_pil, text=text, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     results = processor.post_process_grounded_object_detection(
+         outputs,
+         inputs.input_ids,
+         box_threshold=box_threshold,
+         text_threshold=text_threshold,
+         target_sizes=[img_pil.size[::-1]]  # PIL size is (w, h); the model wants (h, w)
+     )
+
+     boxes = results[0]["boxes"].cpu().numpy()
+     return boxes
+
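+ # Example: queries=["flag", "street name sign"] become the single prompt
+ # "flag. street name sign." that detect_gdino hands to the processor.
+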
+ def try_ocr():
+     """
+     Attempts to load PaddleOCR. Returns the model or None if it fails.
+     """
+     global ocr_model
+     if ocr_model is None:
+         try:
+             from paddleocr import PaddleOCR
+             print("Loading PaddleOCR...")
+             ocr_model = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
+             print("PaddleOCR loaded.")
+         except ImportError:
+             print("PaddleOCR not found. Skipping OCR detection.")
+         except Exception as e:
+             print(f"Error loading PaddleOCR: {e}. Skipping OCR detection.")
+     return ocr_model
+
+ def detect_ocr_boxes(image_bgr, ocr):
+     """
+     Detects text bounding boxes using PaddleOCR.
+     """
+     results = ocr.ocr(image_bgr, cls=True)
+     boxes = []
+     if results and results[0]:
+         for line in results[0]:
+             points = line[0]  # quadrilateral: four (x, y) corner points
+             if points:
+                 # Collapse the quadrilateral to an axis-aligned box
+                 x_coords = [p[0] for p in points]
+                 y_coords = [p[1] for p in points]
+                 x_min, x_max = min(x_coords), max(x_coords)
+                 y_min, y_max = min(y_coords), max(y_coords)
+                 boxes.append([x_min, y_min, x_max, y_max])
+     return np.array(boxes)
+
+ def union_masks(image_shape, box_lists):
+     """
+     Creates a single binary mask from a list of bounding box arrays.
+     """
+     mask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8)
+     for boxes in box_lists:
+         if boxes is not None and len(boxes) > 0:
+             for box in boxes:
+                 x_min, y_min, x_max, y_max = [int(v) for v in box]
+                 # Clamp to zero so stray negative coordinates cannot wrap around
+                 x_min, y_min = max(x_min, 0), max(y_min, 0)
+                 mask[y_min:y_max, x_min:x_max] = 255
+     return mask
+
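+ # Example: a box [10, 20, 110, 220] sets mask[20:220, 10:110] = 255, i.e. a
+ # white rectangle on an otherwise black single-channel mask.
+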
+ def redact(image, mask, method="blur", blur_ksize=151, mosaic_scale=0.06):
+     """
+     Applies the chosen redaction method (blur or pixelate) to the image based on the mask.
+     """
+     if method == "blur":
+         if blur_ksize % 2 == 0:
+             blur_ksize += 1  # GaussianBlur requires an odd kernel size
+         blurred = cv2.GaussianBlur(image, (blur_ksize, blur_ksize), 0)
+         return np.where(mask[:, :, None] == 255, blurred, image)
+     elif method == "pixelate":
+         h, w = image.shape[:2]
+         # Downscale, then upscale with nearest-neighbour to get the mosaic effect
+         small_h = max(int(h * mosaic_scale), 1)
+         small_w = max(int(w * mosaic_scale), 1)
+         resized = cv2.resize(image, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
+         pixelated = cv2.resize(resized, (w, h), interpolation=cv2.INTER_NEAREST)
+         return np.where(mask[:, :, None] == 255, pixelated, image)
+     return image
+
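+ # Example: with the default mosaic_scale=0.06, a 1000x800 image is first
+ # resized down to 60x48, so each mosaic cell covers roughly 17x17 pixels.
+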
+ # Gradio processing function
+ def process_image(image_pil, redaction_targets, redaction_method):
+     """
+     Main function for the Gradio interface.
+
+     Args:
+         image_pil (PIL.Image): The input image.
+         redaction_targets (list): A list of strings naming the items to redact.
+         redaction_method (str): The redaction method ('blur' or 'pixelate').
+
+     Returns:
+         tuple: The path to the redacted image file and a text summary of the detections.
+     """
+     if image_pil is None:
+         return None, "Please upload an image."
+
+     # Load models (cached after the first call)
+     processor, gdino_model = load_gdino()
+     ocr_model = try_ocr()
+     geo_model = load_geoclip()
+
+     img_bgr = load_image(image_pil)
+
+     # Define queries based on the selected checkboxes
+     queries = []
+     if "Flags" in redaction_targets:
+         queries.extend(["flag", "country flags", "state flags"])
+     if "Signs" in redaction_targets:
+         queries.extend(["street name sign", "road name sign"])
+     if "Faces" in redaction_targets:
+         queries.extend(["human faces", "faces", "people faces", "child faces", "human head", "people head"])
+     if "Building/Flat Numbers" in redaction_targets:
+         queries.extend(["housing block number", "flat number", "level number", "floor number", "block number"])
+
+     # Detect boxes with Grounding DINO
+     boxes_gd = detect_gdino(image_pil, processor, gdino_model, 0.25, 0.20, queries)
+
+     # Detect OCR boxes if text redaction is enabled and OCR is available
+     boxes_ocr = detect_ocr_boxes(img_bgr, ocr_model) if "Text" in redaction_targets and ocr_model else np.empty((0, 4), dtype=int)
+
+     # Create a union mask and redact the image
+     mask = union_masks(img_bgr.shape, [boxes_gd, boxes_ocr])
+     redacted_image = redact(img_bgr, mask, method=redaction_method)
+
+     # Run GeoCLIP prediction on the original image
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
+         image_pil.convert("RGB").save(tmp.name)  # JPEG cannot store an alpha channel
+         tmp_path = tmp.name
+
+     try:
+         top_pred_gps, top_pred_prob = geo_model.predict(tmp_path, top_k=1)
+         gps = [round(item, 3) for item in top_pred_gps.tolist()[0]]
+         prob = round(top_pred_prob.tolist()[0] * 100, 3)
+     finally:
+         os.unlink(tmp_path)
+
+     # Convert BGR to RGB for Gradio display and save to a temporary file
+     redacted_image_rgb = cv2.cvtColor(redacted_image, cv2.COLOR_BGR2RGB)
+     temp_img_path = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg").name
+     Image.fromarray(redacted_image_rgb).save(temp_img_path)
+
+     # Create the text output
+     num_gd_boxes = len(boxes_gd)
+     num_ocr_boxes = len(boxes_ocr)
+     total_boxes = num_gd_boxes + num_ocr_boxes
+
+     result_text = f"Redaction Complete! 🎯\n\nDetected and redacted {total_boxes} items.\n"
+     if num_gd_boxes > 0:
+         result_text += f" - {num_gd_boxes} item(s) detected by Grounding DINO.\n"
+     if num_ocr_boxes > 0:
+         result_text += f" - {num_ocr_boxes} item(s) detected by OCR.\n"
+
+     result_text += "\n--- Approximate GPS Prediction ---\n"
+     result_text += f"Predicted GPS: Latitude {gps[0]}, Longitude {gps[1]}\n"
+     result_text += f"Confidence: {prob}%\n"
+
+     return temp_img_path, result_text
+
+ # Define Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Redaction and Geolocation Tool 🌍")
+     gr.Markdown(
+         "Upload an image and select the categories you wish to redact. The tool will "
+         "automatically detect and obscure the selected items using a blur or pixelate effect. "
+         "It will also provide a privacy-preserving approximate GPS location prediction using GeoCLIP."
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload Image")
+             redaction_targets = gr.CheckboxGroup(
+                 choices=["Flags", "Signs", "Faces", "Building/Flat Numbers", "Text"],
+                 label="Select Redaction Targets"
+             )
+             redaction_method = gr.Radio(
+                 choices=["blur", "pixelate"],
+                 label="Redaction Method",
+                 value="blur"
+             )
+             process_button = gr.Button("Redact & Predict")
+
+         with gr.Column():
+             image_output = gr.Image(label="Redacted Image")
+             result_output = gr.Textbox(label="Results", interactive=False)
+
+     process_button.click(
+         fn=process_image,
+         inputs=[image_input, redaction_targets, redaction_method],
+         outputs=[image_output, result_output]
+     )
+
+     gr.Examples(
+         examples=[
+             ["images/image2.png", ["Flags"], "blur"],
+             ["images/image1.png", ["Signs"], "pixelate"]
+         ],
+         inputs=[image_input, redaction_targets, redaction_method],
+         outputs=[image_output, result_output],
+         fn=process_image,
+         cache_examples=False
+     )
+
+ demo.launch()
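
For reference, process_image() can also be exercised without the UI. A minimal sketch, assuming demo.launch() is kept out of the import path (e.g. guarded by if __name__ == "__main__":) and that images/image1.png exists locally:

from PIL import Image
from app import process_image

img = Image.open("images/image1.png")
out_path, summary = process_image(img, ["Signs", "Text"], "pixelate")
print(summary)   # detection counts plus the approximate GPS prediction
print(out_path)  # path to the redacted JPEG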
images/image1.png ADDED

Git LFS Details

  • SHA256: 33bf1d2b69effb6104736b1c2bdbd99d88e899412b8d7b56c977136dadb1d93d
  • Pointer size: 131 Bytes
  • Size of remote file: 979 kB
images/image2.png ADDED

Git LFS Details

  • SHA256: f15cb52cf786309b5bbc1bf6f94283362a741e6d56ceab865d433b8651bfc05a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
images/image3.png ADDED

Git LFS Details

  • SHA256: 1564879f7f1667f14639858140caab4b42b641a3a9004d6df658a8a75ea89eae
  • Pointer size: 131 Bytes
  • Size of remote file: 777 kB
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ transformers
+ opencv-python
+ numpy
+ Pillow
+ paddleocr
+ geoclip
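
Note: depending on the paddleocr version, the paddlepaddle runtime may need to be installed separately; if OCR fails to load for any reason, try_ocr() in app.py logs the error and the app simply runs without text redaction.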