import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from geoclip import GeoCLIP
import tempfile
import os

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global model variables
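# Each model is loaded lazily on first use and cached here so repeated requests reuse it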
processor, gdino_model, ocr_model, geo_model = None, None, None, None

def load_image(image_pil):
    """
    Converts a PIL (RGB) image to a BGR NumPy array for OpenCV.
    """
    return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

def load_gdino():
    """
    Loads and returns the Grounding DINO model and processor.
    """
    global processor, gdino_model
    if gdino_model is None:
        print("Loading Grounding DINO model...")
        model_id = "IDEA-Research/grounding-dino-base"
        processor = AutoProcessor.from_pretrained(model_id)
        gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        print("Grounding DINO model loaded.")
    return processor, gdino_model

def load_geoclip():
    """
    Loads and returns the GeoCLIP model.
    """
    global geo_model
    if geo_model is None:
        print("Loading GeoCLIP model...")
        geo_model = GeoCLIP()
        print("GeoCLIP model loaded.")
    return geo_model

def detect_gdino(img_pil, processor, model, box_threshold, text_threshold, queries):
    """
    Performs object detection using Grounding DINO.
    """
    if not queries:
        return np.empty((0, 4), dtype=int)
    
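    # Grounding DINO expects a single lowercase prompt with phrases separated by ". "
    # and ending in a period, e.g. "flag. street name sign."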
    text = ". ".join([q.lower() for q in queries]) + "."
    inputs = processor(images=img_pil, text=text, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
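    # target_sizes expects (height, width); PIL's .size is (width, height), hence the [::-1]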
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[img_pil.size[::-1]]
    )
    
    boxes = results[0]["boxes"].cpu().numpy()
    return boxes

def try_ocr():
    """
    Attempts to load PaddleOCR. Returns the model or None if it fails.
    """
    global ocr_model
    if ocr_model is None:
        try:
            from paddleocr import PaddleOCR
            print("Loading PaddleOCR...")
            ocr_model = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
            print("PaddleOCR loaded.")
        except ImportError:
            print("PaddleOCR not found. Skipping OCR detection.")
        except Exception as e:
            print(f"Error loading PaddleOCR: {e}. Skipping OCR detection.")
    return ocr_model

def detect_ocr_boxes(image_bgr, ocr):
    """
    Detects text bounding boxes using PaddleOCR.
    """
    results = ocr.ocr(image_bgr, cls=True)
    boxes = []
    if results and results[0]:
        for line in results[0]:
            points = line[0]
            if points:
                x_coords = [p[0] for p in points]
                y_coords = [p[1] for p in points]
                x_min, x_max = min(x_coords), max(x_coords)
                y_min, y_max = min(y_coords), max(y_coords)
                boxes.append([x_min, y_min, x_max, y_max])
    return np.array(boxes)

def union_masks(image_shape, box_lists):
    """
    Creates a single mask from a list of bounding box arrays.
    """
    mask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8)
    for boxes in box_lists:
        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                x_min, y_min, x_max, y_max = [int(v) for v in box]
                mask[y_min:y_max, x_min:x_max] = 255
    return mask

def redact(image, mask, method="blur", blur_ksize=151, mosaic_scale=0.06):
    """
    Applies the chosen redaction method (blur or pixelate) to the image based on the mask.
    """
    if method == "blur":
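        # cv2.GaussianBlur requires an odd kernel size, so even values are bumped up by one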
        if blur_ksize % 2 == 0:
            blur_ksize += 1
        blurred = cv2.GaussianBlur(image, (blur_ksize, blur_ksize), 0)
        return np.where(mask[:, :, None] == 255, blurred, image)
    elif method == "pixelate":
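        # Downscale, then upscale with nearest-neighbour interpolation to get the mosaic effect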
        h, w = image.shape[:2]
        small_h = int(h * mosaic_scale)
        small_w = int(w * mosaic_scale)
        if small_h <= 0: small_h = 1
        if small_w <= 0: small_w = 1
        
        resized = cv2.resize(image, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
        pixelated = cv2.resize(resized, (w, h), interpolation=cv2.INTER_NEAREST)
        return np.where(mask[:, :, None] == 255, pixelated, image)
    return image

# Gradio processing function
def process_image(image_pil, redaction_targets, redaction_method):
    """
    Main function for the Gradio interface.
    
    Args:
        image_pil (PIL.Image): The input image.
        redaction_targets (list): A list of strings representing the items to redact.
        redaction_method (str): The method to use for redaction ('blur' or 'pixelate').
    
    Returns:
        tuple: A tuple containing the path to the redacted image file and a text string with detection results.
    """
    
    if image_pil is None:
        return None, "Please upload an image."

    # Load models lazily (each loader caches its model in a global after the first call)
    processor, gdino_model = load_gdino()
    ocr_model = try_ocr()
    geo_model = load_geoclip()

    img_bgr = load_image(image_pil)
    
    # Define queries based on checkboxes
    queries = []
    if "Flags" in redaction_targets:
        queries.extend(["flag", "country flags", "state flags"])
    if "Signs" in redaction_targets:
        queries.extend(["street name sign", "road name sign"])
    if "Faces" in redaction_targets:
        queries.extend(["human faces", "faces", "people faces", "child faces", "human head", "people head"])
    if "Building/Flat Numbers" in redaction_targets:
        queries.extend(["housing block number", "flat number", "level number", "floor number", "block number"])

    # Detect boxes
    boxes_gd = detect_gdino(image_pil, processor, gdino_model, 0.25, 0.20, queries)
    
    # Run OCR only when 'Text' is selected and PaddleOCR loaded successfully
    boxes_ocr = (
        detect_ocr_boxes(img_bgr, ocr_model)
        if "Text" in redaction_targets and ocr_model
        else np.empty((0, 4), dtype=int)
    )
    
    # Create a union mask
    mask = union_masks(img_bgr.shape, [boxes_gd, boxes_ocr])
    
    # Redact the image
    redacted_image = redact(img_bgr, mask, method=redaction_method)
    
    # Run GeoCLIP prediction
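    # GeoCLIP's predict() takes an image file path, so the original (unredacted)
    # input is written to a temporary file first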
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        image_pil.save(tmp.name)
        tmp_path = tmp.name
    
    try:
        top_pred_gps, top_pred_prob = geo_model.predict(tmp_path, top_k=1)
        
        gps = [round(item, 3) for item in top_pred_gps.tolist()[0]]
        prob = round(top_pred_prob.tolist()[0] * 100, 3)
    finally:
        os.unlink(tmp_path)

    # Convert BGR back to RGB and save to a temporary file for Gradio to display
    redacted_image_rgb = cv2.cvtColor(redacted_image, cv2.COLOR_BGR2RGB)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_out:
        temp_img_path = tmp_out.name
    Image.fromarray(redacted_image_rgb).save(temp_img_path)
    
    # Create the text output
    num_gd_boxes = len(boxes_gd)
    num_ocr_boxes = len(boxes_ocr)
    total_boxes = num_gd_boxes + num_ocr_boxes
    
    result_text = f"Redaction Complete! 🎯\n\nDetected and redacted {total_boxes} items.\n"
    if num_gd_boxes > 0:
        result_text += f"  - {num_gd_boxes} item(s) detected by Grounding DINO.\n"
    if num_ocr_boxes > 0:
        result_text += f"  - {num_ocr_boxes} item(s) detected by OCR.\n"
    
    result_text += "\n--- Approximate GPS Prediction ---\n"
    result_text += f"Predicted GPS: Latitude {gps[0]}, Longitude {gps[1]}\n"
    result_text += f"Confidence: {prob}%\n"

    return temp_img_path, result_text

# Define Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Redaction and Geolocation Tool 🌍")
    gr.Markdown(
        "Upload an image and select the categories you wish to redact. The tool will "
        "automatically detect and obscure the selected items using a blur or pixelate effect. "
        "It will also provide an approximate GPS location prediction for the image using GeoCLIP. "
        "Developed for TikTok TechJam 2025 (Privacy x AI), where the goal was to build an app "
        "that can automatically blur or filter sensitive location information. "
        "This Space runs on the free CPU tier, so expect processing to take about 1.5 minutes per image!"
    )
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            redaction_targets = gr.CheckboxGroup(
                choices=["Flags", "Signs", "Faces", "Building/Flat Numbers", "Text"],
                label="Select Redaction Targets"
            )
            redaction_method = gr.Radio(
                choices=["blur", "pixelate"],
                label="Redaction Method",
                value="blur"
            )
            process_button = gr.Button("Redact & Predict")
        
        with gr.Column():
            image_output = gr.Image(label="Redacted Image")
            result_output = gr.Textbox(label="Results", interactive=False)
            
    process_button.click(
        fn=process_image,
        inputs=[image_input, redaction_targets, redaction_method],
        outputs=[image_output, result_output]
    )

    gr.Examples(
        examples=[
            ["images/image2.png", ["Flags"], "blur"],
            ["images/image1.png", ["Signs"], "pixelate"]
        ],
        inputs=[image_input, redaction_targets, redaction_method],
        outputs=[image_output, result_output],
        fn=process_image,
        cache_examples=False
    )
    
demo.launch()