import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from geoclip import GeoCLIP

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global model variables
processor, gdino_model, ocr_model, geo_model = None, None, None, None
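# Models are loaded lazily on first use (see the load_* helpers below), so the
# Space starts quickly and pays the download/initialisation cost only once.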


def load_image(image_pil):
    """
    Converts a PIL image to a BGR NumPy array.
    """
    # Force 3-channel RGB first so RGBA or greyscale uploads don't break
    # cvtColor (cv2.cvtColor raises on error rather than returning None).
    img_bgr = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
    return img_bgr


def load_gdino():
    """
    Loads and returns the Grounding DINO model and processor.
    """
    global processor, gdino_model
    if gdino_model is None:
        print("Loading Grounding DINO model...")
        model_id = "IDEA-Research/grounding-dino-base"
        processor = AutoProcessor.from_pretrained(model_id)
        gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        print("Grounding DINO model loaded.")
    return processor, gdino_model


def load_geoclip():
    """
    Loads and returns the GeoCLIP model.
    """
    global geo_model
    if geo_model is None:
        print("Loading GeoCLIP model...")
        geo_model = GeoCLIP()
        print("GeoCLIP model loaded.")
    return geo_model
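
# Note: GeoCLIP() typically downloads its pretrained weights the first time it
# is constructed, so the first request after a cold start is noticeably slower
# than later ones.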


def detect_gdino(img_pil, processor, model, box_threshold, text_threshold, queries):
    """
    Performs object detection using Grounding DINO.
    """
    if not queries:
        return np.empty((0, 4), dtype=int)
    text = ". ".join([q.lower() for q in queries]) + "."
    inputs = processor(images=img_pil, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[img_pil.size[::-1]]
    )
    boxes = results[0]["boxes"].cpu().numpy()
    return boxes
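
# Prompt format example: queries=["flag", "street name sign"] is joined into
# "flag. street name sign." -- Grounding DINO expects lowercase phrases
# separated by periods.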


def try_ocr():
    """
    Attempts to load PaddleOCR. Returns the model or None if it fails.
    """
    global ocr_model
    if ocr_model is None:
        try:
            from paddleocr import PaddleOCR
            print("Loading PaddleOCR...")
            ocr_model = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
            print("PaddleOCR loaded.")
        except ImportError:
            print("PaddleOCR not found. Skipping OCR detection.")
        except Exception as e:
            print(f"Error loading PaddleOCR: {e}. Skipping OCR detection.")
    return ocr_model
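
# Compatibility note (assumption): the kwargs above target PaddleOCR 2.x; the
# 3.x releases renamed use_angle_cls to use_textline_orientation and dropped
# show_log, so either pin paddleocr<3 or adjust the constructor call.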


def detect_ocr_boxes(image_bgr, ocr):
    """
    Detects text bounding boxes using PaddleOCR.
    """
    results = ocr.ocr(image_bgr, cls=True)
    boxes = []
    if results and results[0]:
        for line in results[0]:
            points = line[0]
            if points:
                x_coords = [p[0] for p in points]
                y_coords = [p[1] for p in points]
                x_min, x_max = min(x_coords), max(x_coords)
                y_min, y_max = min(y_coords), max(y_coords)
                boxes.append([x_min, y_min, x_max, y_max])
    return np.array(boxes)


def union_masks(image_shape, box_lists):
    """
    Creates a single mask from a list of bounding box arrays.
    """
    mask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8)
    for boxes in box_lists:
        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                # Clamp to the image so slightly out-of-range detector
                # coordinates (e.g. a negative x_min) cannot wrap around.
                x_min, y_min, x_max, y_max = [max(0, int(v)) for v in box]
                mask[y_min:y_max, x_min:x_max] = 255
    return mask


def redact(image, mask, method="blur", blur_ksize=151, mosaic_scale=0.06):
    """
    Applies the chosen redaction method (blur or pixelate) to the image based on the mask.
    """
    if method == "blur":
        if blur_ksize % 2 == 0:
            blur_ksize += 1  # GaussianBlur requires an odd kernel size
        blurred = cv2.GaussianBlur(image, (blur_ksize, blur_ksize), 0)
        return np.where(mask[:, :, None] == 255, blurred, image)
    elif method == "pixelate":
        h, w = image.shape[:2]
        small_h = max(1, int(h * mosaic_scale))
        small_w = max(1, int(w * mosaic_scale))
        resized = cv2.resize(image, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
        pixelated = cv2.resize(resized, (w, h), interpolation=cv2.INTER_NEAREST)
        return np.where(mask[:, :, None] == 255, pixelated, image)
    return image
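
# Standalone usage sketch (hypothetical file name, for illustration only):
#   img = cv2.imread("photo.jpg")
#   mask = union_masks(img.shape, [np.array([[10, 10, 120, 80]])])
#   out = redact(img, mask, method="pixelate")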


# Gradio processing function
def process_image(image_pil, redaction_targets, redaction_method):
    """
    Main function for the Gradio interface.

    Args:
        image_pil (PIL.Image): The input image.
        redaction_targets (list): A list of strings representing the items to redact.
        redaction_method (str): The method to use for redaction ('blur' or 'pixelate').

    Returns:
        tuple: A tuple containing the path to the redacted image file and a text
        string with detection results.
    """
    if image_pil is None:
        return None, "Please upload an image."

    # Load models (no-ops after the first call)
    processor, gdino_model = load_gdino()
    ocr_model = try_ocr()
    geo_model = load_geoclip()

    img_bgr = load_image(image_pil)

    # Define queries based on checkboxes
    queries = []
    if "Flags" in redaction_targets:
        queries.extend(["flag", "country flags", "state flags"])
    if "Signs" in redaction_targets:
        queries.extend(["street name sign", "road name sign"])
    if "Faces" in redaction_targets:
        queries.extend(["human faces", "faces", "people faces", "child faces", "human head", "people head"])
    if "Building/Flat Numbers" in redaction_targets:
        queries.extend(["housing block number", "flat number", "level number", "floor number", "block number"])

    # Detect boxes with Grounding DINO
    boxes_gd = detect_gdino(image_pil, processor, gdino_model, 0.25, 0.20, queries)

    # Detect OCR boxes if text redaction is enabled and PaddleOCR is available
    if "Text" in redaction_targets and ocr_model:
        boxes_ocr = detect_ocr_boxes(img_bgr, ocr_model)
    else:
        boxes_ocr = np.empty((0, 4), dtype=int)

    # Create a union mask and redact the image
    mask = union_masks(img_bgr.shape, [boxes_gd, boxes_ocr])
    redacted_image = redact(img_bgr, mask, method=redaction_method)
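
    # Note: GeoCLIP runs on the ORIGINAL upload (not the redacted copy), so the
    # reported confidence reflects what could be inferred before redaction.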
    # Run GeoCLIP prediction
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        # Convert to RGB first: PIL cannot save RGBA images as JPEG.
        image_pil.convert("RGB").save(tmp.name)
        tmp_path = tmp.name
    try:
        top_pred_gps, top_pred_prob = geo_model.predict(tmp_path, top_k=1)
        gps = [round(item, 3) for item in top_pred_gps.tolist()[0]]
        prob = round(top_pred_prob.tolist()[0] * 100, 3)
    finally:
        os.unlink(tmp_path)
    # Convert BGR to RGB for Gradio display and save to a temporary file
    redacted_image_rgb = cv2.cvtColor(redacted_image, cv2.COLOR_BGR2RGB)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_img:
        temp_img_path = tmp_img.name
    Image.fromarray(redacted_image_rgb).save(temp_img_path)
    # Create the text output
    num_gd_boxes = len(boxes_gd)
    num_ocr_boxes = len(boxes_ocr)
    total_boxes = num_gd_boxes + num_ocr_boxes
    result_text = f"Redaction Complete!\n\nDetected and redacted {total_boxes} items.\n"
    if num_gd_boxes > 0:
        result_text += f" - {num_gd_boxes} item(s) detected by Grounding DINO.\n"
    if num_ocr_boxes > 0:
        result_text += f" - {num_ocr_boxes} item(s) detected by OCR.\n"
    result_text += "\n--- Approximate GPS Prediction ---\n"
    result_text += f"Predicted GPS: Latitude {gps[0]}, Longitude {gps[1]}\n"
    result_text += f"Confidence: {prob}%\n"
    return temp_img_path, result_text


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Redaction and Geolocation Tool")
    gr.Markdown(
        "Upload an image and select the categories you wish to redact. The tool will "
        "automatically detect and obscure the selected items using a blur or pixelate effect. "
        "It will also provide a privacy-preserving approximate GPS location prediction using GeoCLIP.\n\n"
        "Developed for TikTok TechJam 2025 (Privacy x AI), where the goal was to build an app "
        "that can automatically blur or filter sensitive location information.\n\n"
        "This Space runs on the free CPU tier, so expect processing to be slow (~1.5 min per image)."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            redaction_targets = gr.CheckboxGroup(
                choices=["Flags", "Signs", "Faces", "Building/Flat Numbers", "Text"],
                label="Select Redaction Targets"
            )
            redaction_method = gr.Radio(
                choices=["blur", "pixelate"],
                label="Redaction Method",
                value="blur"
            )
            process_button = gr.Button("Redact & Predict")
        with gr.Column():
            image_output = gr.Image(label="Redacted Image")
            result_output = gr.Textbox(label="Results", interactive=False)

    process_button.click(
        fn=process_image,
        inputs=[image_input, redaction_targets, redaction_method],
        outputs=[image_output, result_output]
    )
    gr.Examples(
        examples=[
            ["images/image2.png", ["Flags"], "blur"],
            ["images/image1.png", ["Signs"], "pixelate"]
        ],
        inputs=[image_input, redaction_targets, redaction_method],
        outputs=[image_output, result_output],
        fn=process_image,
        cache_examples=False
    )

demo.launch()
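
# When running locally, demo.launch(share=True) would also expose a temporary
# public link via Gradio's built-in tunnelling; on Hugging Face Spaces the
# plain launch() above is all that's needed.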