# Geo-Safely / app.py
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from geoclip import GeoCLIP
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Global model handles; populated lazily by the load_* helpers so each model is loaded at most once
processor, gdino_model, ocr_model, geo_model = None, None, None, None
def load_image(image_pil):
    """
    Converts a PIL image to a BGR NumPy array for OpenCV processing.
    """
    # Force RGB first: cv2.cvtColor raises on a mismatched channel count
    # (e.g. RGBA input), so the old `is None` check could never fire.
    img_bgr = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
    return img_bgr
def load_gdino():
"""
Loads and returns the Grounding DINO model and processor.
"""
global processor, gdino_model
if gdino_model is None:
print("Loading Grounding DINO model...")
model_id = "IDEA-Research/grounding-dino-base"
processor = AutoProcessor.from_pretrained(model_id)
gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
print("Grounding DINO model loaded.")
return processor, gdino_model
def load_geoclip():
"""
Loads and returns the GeoCLIP model.
"""
global geo_model
if geo_model is None:
print("Loading GeoCLIP model...")
geo_model = GeoCLIP()
print("GeoCLIP model loaded.")
return geo_model
def detect_gdino(img_pil, processor, model, box_threshold, text_threshold, queries):
"""
Performs object detection using Grounding DINO.
"""
if not queries:
return np.empty((0, 4), dtype=int)
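    # Grounding DINO expects a single lowercase prompt with the queries
    # separated by ". " and terminated by a period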
text = ". ".join([q.lower() for q in queries]) + "."
inputs = processor(images=img_pil, text=text, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
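        # PIL's size is (width, height); target_sizes expects (height, width)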
target_sizes=[img_pil.size[::-1]]
)
boxes = results[0]["boxes"].cpu().numpy()
return boxes
def try_ocr():
"""
Attempts to load PaddleOCR. Returns the model or None if it fails.
"""
global ocr_model
if ocr_model is None:
try:
from paddleocr import PaddleOCR
print("Loading PaddleOCR...")
ocr_model = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
print("PaddleOCR loaded.")
except ImportError:
print("PaddleOCR not found. Skipping OCR detection.")
except Exception as e:
print(f"Error loading PaddleOCR: {e}. Skipping OCR detection.")
return ocr_model
def detect_ocr_boxes(image_bgr, ocr):
"""
Detects text bounding boxes using PaddleOCR.
"""
results = ocr.ocr(image_bgr, cls=True)
boxes = []
if results and results[0]:
for line in results[0]:
points = line[0]
if points:
x_coords = [p[0] for p in points]
y_coords = [p[1] for p in points]
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
boxes.append([x_min, y_min, x_max, y_max])
return np.array(boxes)
def union_masks(image_shape, box_lists):
"""
Creates a single mask from a list of bounding box arrays.
"""
mask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8)
for boxes in box_lists:
if boxes is not None and len(boxes) > 0:
            for box in boxes:
                x_min, y_min, x_max, y_max = [int(v) for v in box]
                # Clamp to the image bounds; negative coordinates would
                # otherwise wrap around under NumPy slicing
                x_min, y_min = max(x_min, 0), max(y_min, 0)
                mask[y_min:y_max, x_min:x_max] = 255
return mask
def redact(image, mask, method="blur", blur_ksize=151, mosaic_scale=0.06):
"""
Applies the chosen redaction method (blur or pixelate) to the image based on the mask.
"""
if method == "blur":
if blur_ksize % 2 == 0:
blur_ksize += 1
blurred = cv2.GaussianBlur(image, (blur_ksize, blur_ksize), 0)
return np.where(mask[:, :, None] == 255, blurred, image)
elif method == "pixelate":
h, w = image.shape[:2]
        # Keep at least one pixel in each dimension at tiny mosaic scales
        small_h = max(1, int(h * mosaic_scale))
        small_w = max(1, int(w * mosaic_scale))
resized = cv2.resize(image, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
pixelated = cv2.resize(resized, (w, h), interpolation=cv2.INTER_NEAREST)
return np.where(mask[:, :, None] == 255, pixelated, image)
return image
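# A minimal standalone sketch of how the detection and redaction helpers
# compose (hypothetical file names and box values, kept as comments so the
# module stays importable):
#   img = cv2.imread("street.jpg")            # BGR image
#   boxes = np.array([[40, 60, 220, 180]])    # one region to hide
#   mask = union_masks(img.shape, [boxes])
#   out = redact(img, mask, method="pixelate")
#   cv2.imwrite("street_redacted.jpg", out)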
# Gradio processing function
def process_image(image_pil, redaction_targets, redaction_method):
"""
Main function for the Gradio interface.
Args:
image_pil (PIL.Image): The input image.
redaction_targets (list): A list of strings representing the items to redact.
redaction_method (str): The method to use for redaction ('blur' or 'pixelate').
Returns:
tuple: A tuple containing the path to the redacted image file and a text string with detection results.
"""
    if image_pil is None:
        return None, "Please upload an image."
    # Load models (cached in module globals after the first call)
    processor, gdino_model = load_gdino()
    ocr_model = try_ocr()
    geo_model = load_geoclip()
    img_bgr = load_image(image_pil)
# Define queries based on checkboxes
queries = []
if "Flags" in redaction_targets:
queries.extend(["flag", "country flags", "state flags"])
if "Signs" in redaction_targets:
queries.extend(["street name sign", "road name sign"])
if 'Faces' in redaction_targets:
queries.extend(["human faces", "faces", "people faces", "child faces", "human head", "people head"])
if 'Building/Flat Numbers' in redaction_targets:
queries.extend(["housing block number", "flat number", "level number", "floor number", "block number"])
# Detect boxes
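    # Low thresholds (box=0.25, text=0.20) favour recall over precision:
    # for redaction, a spurious blur is cheaper than a missed face or sign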
boxes_gd = detect_gdino(image_pil, processor, gdino_model, 0.25, 0.20, queries)
    # Detect OCR text boxes only when requested and PaddleOCR loaded successfully
    if "Text" in redaction_targets and ocr_model:
        boxes_ocr = detect_ocr_boxes(img_bgr, ocr_model)
    else:
        boxes_ocr = np.empty((0, 4), dtype=int)
# Create a union mask
mask = union_masks(img_bgr.shape, [boxes_gd, boxes_ocr])
# Redact the image
redacted_image = redact(img_bgr, mask, method=redaction_method)
# Run GeoCLIP prediction
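    # Note: GeoCLIP runs on the original upload, not the redacted copy, so the
    # reported location reflects what the unedited image could leak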
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        # JPEG cannot store an alpha channel, so force RGB before saving
        image_pil.convert("RGB").save(tmp.name)
        tmp_path = tmp.name
try:
top_pred_gps, top_pred_prob = geo_model.predict(tmp_path, top_k=1)
gps = [round(item, 3) for item in top_pred_gps.tolist()[0]]
prob = round(top_pred_prob.tolist()[0] * 100, 3)
finally:
os.unlink(tmp_path)
    # Convert BGR to RGB for Gradio display and save to a temporary file
    redacted_image_rgb = cv2.cvtColor(redacted_image, cv2.COLOR_BGR2RGB)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_out:
        temp_img_path = tmp_out.name
    Image.fromarray(redacted_image_rgb).save(temp_img_path)
# Create the text output
num_gd_boxes = len(boxes_gd)
num_ocr_boxes = len(boxes_ocr)
total_boxes = num_gd_boxes + num_ocr_boxes
result_text = f"Redaction Complete! 🎯\n\nDetected and redacted {total_boxes} items.\n"
if num_gd_boxes > 0:
result_text += f" - {num_gd_boxes} item(s) detected by Grounding DINO.\n"
if num_ocr_boxes > 0:
result_text += f" - {num_ocr_boxes} item(s) detected by OCR.\n"
result_text += f"\n--- Approximate GPS Prediction ---\n"
result_text += f"Predicted GPS: Latitude {gps[0]}, Longitude {gps[1]}\n"
result_text += f"Confidence: {prob}%\n"
return temp_img_path, result_text
# Define Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("# Image Redaction and Geolocation Tool 🌍")
    gr.Markdown(
        "Upload an image and select the categories you wish to redact. The tool will "
        "automatically detect and obscure the selected items using a blur or pixelate effect, "
        "and it will also provide a privacy-preserving approximate GPS location prediction using GeoCLIP.\n\n"
        "Developed for TikTok TechJam 2025 (Privacy x AI), where the goal was to build an app "
        "that can automatically blur or filter sensitive location information.\n\n"
        "This Space runs on the free CPU tier, so expect it to be slow: roughly 1.5 minutes per image!"
    )
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image")
redaction_targets = gr.CheckboxGroup(
choices=["Flags", "Signs", "Faces", "Building/Flat Numbers", "Text"],
label="Select Redaction Targets"
)
redaction_method = gr.Radio(
choices=["blur", "pixelate"],
label="Redaction Method",
value="blur"
)
process_button = gr.Button("Redact & Predict")
with gr.Column():
            image_output = gr.Image(label="Redacted Image")
result_output = gr.Textbox(label="Results", interactive=False)
process_button.click(
fn=process_image,
inputs=[image_input, redaction_targets, redaction_method],
outputs=[image_output, result_output]
)
gr.Examples(
examples=[
["images/image2.png", ["Flags"], "blur"],
["images/image1.png", ["Signs"], "pixelate"]
],
inputs=[image_input, redaction_targets, redaction_method],
outputs=[image_output, result_output],
fn=process_image,
cache_examples=False
)
demo.launch()