import gradio as gr
import io
import cv2
import numpy as np
import torch
from PIL import Image
import sys
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from geoclip import GeoCLIP
import tempfile
import os
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Global model variables
processor, gdino_model, ocr_model, geo_model = None, None, None, None
def load_image(image_pil):
    """
    Converts a PIL image to a BGR NumPy array for OpenCV processing.
    """
    # Force 3-channel RGB first; uploaded PNGs may be RGBA or grayscale,
    # which cv2.COLOR_RGB2BGR cannot convert.
    return cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
def load_gdino():
"""
Loads and returns the Grounding DINO model and processor.
"""
global processor, gdino_model
if gdino_model is None:
print("Loading Grounding DINO model...")
model_id = "IDEA-Research/grounding-dino-base"
processor = AutoProcessor.from_pretrained(model_id)
gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
print("Grounding DINO model loaded.")
return processor, gdino_model
def load_geoclip():
"""
Loads and returns the GeoCLIP model.
"""
global geo_model
if geo_model is None:
print("Loading GeoCLIP model...")
geo_model = GeoCLIP()
print("GeoCLIP model loaded.")
return geo_model
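# Note: instantiating GeoCLIP() may download pretrained weights on first use,
# so the first request after a cold start can take noticeably longer.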
def detect_gdino(img_pil, processor, model, box_threshold, text_threshold, queries):
"""
Performs object detection using Grounding DINO.
"""
if not queries:
return np.empty((0, 4), dtype=int)
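    # Grounding DINO expects a single lowercase prompt with phrases
    # separated by periods, e.g. "flag. street name sign."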
text = ". ".join([q.lower() for q in queries]) + "."
inputs = processor(images=img_pil, text=text, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[img_pil.size[::-1]]
)
boxes = results[0]["boxes"].cpu().numpy()
return boxes
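# Usage sketch (hypothetical query list; assumes `img` is a PIL.Image):
#   processor, model = load_gdino()
#   boxes = detect_gdino(img, processor, model, 0.25, 0.20, ["flag", "street name sign"])
#   # -> (N, 4) array of [x_min, y_min, x_max, y_max] pixel coordinates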
def try_ocr():
"""
Attempts to load PaddleOCR. Returns the model or None if it fails.
"""
global ocr_model
if ocr_model is None:
try:
from paddleocr import PaddleOCR
print("Loading PaddleOCR...")
ocr_model = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
print("PaddleOCR loaded.")
except ImportError:
print("PaddleOCR not found. Skipping OCR detection.")
except Exception as e:
print(f"Error loading PaddleOCR: {e}. Skipping OCR detection.")
return ocr_model
def detect_ocr_boxes(image_bgr, ocr):
"""
Detects text bounding boxes using PaddleOCR.
"""
results = ocr.ocr(image_bgr, cls=True)
boxes = []
if results and results[0]:
for line in results[0]:
points = line[0]
if points:
x_coords = [p[0] for p in points]
y_coords = [p[1] for p in points]
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
boxes.append([x_min, y_min, x_max, y_max])
return np.array(boxes)
def union_masks(image_shape, box_lists):
"""
Creates a single mask from a list of bounding box arrays.
"""
mask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8)
for boxes in box_lists:
if boxes is not None and len(boxes) > 0:
            for box in boxes:
                x_min, y_min, x_max, y_max = [int(v) for v in box]
                # Clip negative coordinates; they would otherwise wrap around
                # as Python negative indices and corrupt the mask.
                x_min, y_min = max(0, x_min), max(0, y_min)
                mask[y_min:y_max, x_min:x_max] = 255
return mask
def redact(image, mask, method="blur", blur_ksize=151, mosaic_scale=0.06):
"""
Applies the chosen redaction method (blur or pixelate) to the image based on the mask.
"""
if method == "blur":
if blur_ksize % 2 == 0:
blur_ksize += 1
blurred = cv2.GaussianBlur(image, (blur_ksize, blur_ksize), 0)
return np.where(mask[:, :, None] == 255, blurred, image)
elif method == "pixelate":
h, w = image.shape[:2]
small_h = int(h * mosaic_scale)
small_w = int(w * mosaic_scale)
if small_h <= 0: small_h = 1
if small_w <= 0: small_w = 1
resized = cv2.resize(image, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
pixelated = cv2.resize(resized, (w, h), interpolation=cv2.INTER_NEAREST)
return np.where(mask[:, :, None] == 255, pixelated, image)
return image
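# Usage sketch (assumes a BGR image `img` and box arrays from the detectors above):
#   mask = union_masks(img.shape, [boxes_gd, boxes_ocr])
#   out = redact(img, mask, method="pixelate", mosaic_scale=0.05)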
# Gradio processing function
def process_image(image_pil, redaction_targets, redaction_method):
"""
Main function for the Gradio interface.
Args:
image_pil (PIL.Image): The input image.
redaction_targets (list): A list of strings representing the items to redact.
redaction_method (str): The method to use for redaction ('blur' or 'pixelate').
Returns:
tuple: A tuple containing the path to the redacted image file and a text string with detection results.
"""
    if image_pil is None:
        return None, "Please upload an image."
    # Load models (lazily; cached in module globals after the first call)
    processor, gdino_model = load_gdino()
    ocr_model = try_ocr()
    geo_model = load_geoclip()
    img_bgr = load_image(image_pil)
# Define queries based on checkboxes
queries = []
if "Flags" in redaction_targets:
queries.extend(["flag", "country flags", "state flags"])
if "Signs" in redaction_targets:
queries.extend(["street name sign", "road name sign"])
    if "Faces" in redaction_targets:
        queries.extend(["human faces", "faces", "people faces", "child faces", "human head", "people head"])
    if "Building/Flat Numbers" in redaction_targets:
        queries.extend(["housing block number", "flat number", "level number", "floor number", "block number"])
# Detect boxes
boxes_gd = detect_gdino(image_pil, processor, gdino_model, 0.25, 0.20, queries)
# Detect OCR boxes if OCR is enabled
    boxes_ocr = (
        detect_ocr_boxes(img_bgr, ocr_model)
        if "Text" in redaction_targets and ocr_model
        else np.empty((0, 4), dtype=int)
    )
# Create a union mask
mask = union_masks(img_bgr.shape, [boxes_gd, boxes_ocr])
# Redact the image
redacted_image = redact(img_bgr, mask, method=redaction_method)
# Run GeoCLIP prediction
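    # GeoCLIP's predict() takes an image path, so write the original
    # (unredacted) input to a temporary file first.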
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
image_pil.save(tmp.name)
tmp_path = tmp.name
try:
top_pred_gps, top_pred_prob = geo_model.predict(tmp_path, top_k=1)
gps = [round(item, 3) for item in top_pred_gps.tolist()[0]]
prob = round(top_pred_prob.tolist()[0] * 100, 3)
finally:
os.unlink(tmp_path)
# Convert BGR to RGB for Gradio display and save to a temporary file
redacted_image_rgb = cv2.cvtColor(redacted_image, cv2.COLOR_BGR2RGB)
temp_img_path = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg").name
Image.fromarray(redacted_image_rgb).save(temp_img_path)
# Create the text output
num_gd_boxes = len(boxes_gd)
num_ocr_boxes = len(boxes_ocr)
total_boxes = num_gd_boxes + num_ocr_boxes
result_text = f"Redaction Complete! π―\n\nDetected and redacted {total_boxes} items.\n"
if num_gd_boxes > 0:
result_text += f" - {num_gd_boxes} item(s) detected by Grounding DINO.\n"
if num_ocr_boxes > 0:
result_text += f" - {num_ocr_boxes} item(s) detected by OCR.\n"
result_text += f"\n--- Approximate GPS Prediction ---\n"
result_text += f"Predicted GPS: Latitude {gps[0]}, Longitude {gps[1]}\n"
result_text += f"Confidence: {prob}%\n"
return temp_img_path, result_text
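# Usage sketch (bypassing the UI; example paths taken from the Examples below):
#   img = Image.open("images/image1.png")
#   out_path, summary = process_image(img, ["Signs", "Text"], "blur")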
# Define Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("# Image Redaction and Geolocation Tool π")
    gr.Markdown(
        "Upload an image and select the categories you wish to redact. The tool "
        "automatically detects and obscures the selected items using a blur or pixelate "
        "effect, and also provides a privacy-preserving approximate GPS location "
        "prediction using GeoCLIP.\n\n"
        "Developed for TikTok TechJam 2025 (Privacy x AI), where the goal was to build "
        "an app that can automatically blur or filter sensitive location information.\n\n"
        "This Space runs on the free CPU tier, so expect slow performance: roughly "
        "1.5 minutes per image."
    )
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image")
redaction_targets = gr.CheckboxGroup(
choices=["Flags", "Signs", "Faces", "Building/Flat Numbers", "Text"],
label="Select Redaction Targets"
)
redaction_method = gr.Radio(
choices=["blur", "pixelate"],
label="Redaction Method",
value="blur"
)
process_button = gr.Button("Redact & Predict")
with gr.Column():
            image_output = gr.Image(label="Redacted Image")
result_output = gr.Textbox(label="Results", interactive=False)
process_button.click(
fn=process_image,
inputs=[image_input, redaction_targets, redaction_method],
outputs=[image_output, result_output]
)
gr.Examples(
examples=[
["images/image2.png", ["Flags"], "blur"],
["images/image1.png", ["Signs"], "pixelate"]
],
inputs=[image_input, redaction_targets, redaction_method],
outputs=[image_output, result_output],
fn=process_image,
cache_examples=False
)
demo.launch()