import random

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import (
    Owlv2ForObjectDetection,
    Owlv2Processor,
    SamModel,
    SamProcessor,
)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OWLv2 for zero-shot, text-prompted object detection
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")

# SAM for segmenting the detected boxes
model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")


def apply_colored_masks_on_image(image, masks):
    """Overlay each mask on the image with a random, semi-transparent color."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype("uint8"), "RGB")
    image_rgba = image.convert("RGBA")
    for i in range(masks.shape[0]):
        mask = masks[i].squeeze().cpu().numpy()
        # putalpha replaces the layer's alpha channel wholesale, so scale the
        # binary mask to 128 here to keep the overlay semi-transparent.
        mask_image = Image.fromarray((mask * 128).astype(np.uint8), "L")
        color = tuple([random.randint(0, 255) for _ in range(3)] + [128])
        colored_mask = Image.new("RGBA", image.size, color)
        colored_mask.putalpha(mask_image)
        image_rgba = Image.alpha_composite(image_rgba, colored_mask)
    return image_rgba


@spaces.GPU
def query_image(img, text_queries, score_threshold=0.1):
    """Detect objects matching the comma-separated queries, then segment them with SAM."""
    text_queries = [query.strip() for query in text_queries.split(",")]

    # OWLv2 pads the input to a square, so post-process against square target sizes.
    size = max(img.shape[:2])
    target_sizes = torch.tensor([[size, size]])

    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        model_outputs = model(**inputs)

    model_outputs.logits = model_outputs.logits.cpu()
    model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(
        outputs=model_outputs, threshold=score_threshold, target_sizes=target_sizes
    )
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    img_pil = Image.fromarray(img.astype("uint8"), "RGB")

    result_labels = []
    result_boxes = []
    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            box = [int(i) for i in box.tolist()]
            result_labels.append((box, text_queries[label.item()]))
            result_boxes.append(box)

    # SAM needs at least one box prompt; fall back to the plain image otherwise.
    if not result_boxes:
        return img_pil, result_labels

    sam_image = generate_image_with_sam(np.array(img_pil), result_boxes)
    return sam_image, result_labels


def generate_image_with_sam(img, input_boxes):
    """Segment the given boxes with SAM and return the image with colored mask overlays."""
    img_pil = Image.fromarray(img.astype("uint8"), "RGB")

    # Encode the image once, then prompt the mask decoder with the detected boxes.
    inputs = processor_sam(img_pil, return_tensors="pt").to(device)
    image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])

    inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)  # reuse the precomputed image embeddings
    inputs.update({"image_embeddings": image_embeddings})

    with torch.no_grad():
        outputs = model_sam(**inputs, multimask_output=False)

    masks = processor_sam.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return apply_colored_masks_on_image(img_pil, masks[0])


demo = gr.Interface(
    fn=query_image,
    inputs=[
        gr.Image(),
        gr.Textbox(label="Query Text"),
        gr.Slider(0, 1, value=0.1, label="Score Threshold"),
    ],
    outputs=gr.AnnotatedImage(),
    title="Zero-Shot Object Detection SV3",
    description=(
        "This interface combines zero-shot object detection (OWLv2) with SAM: "
        "enter comma-separated text queries to detect matching objects and "
        "overlay their segmentation masks."
    ),
    examples=[
        ["images/dark_cell.png", "gray cells", 0.1],
        [
            "images/animals.png",
            "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse",
            0.35,
        ],
    ],
)

demo.launch()