"""Gradio demo: object detection with a fine-tuned Conditional-DETR model.

Loads ``0llheaven/Conditional-detr-finetuned-tf`` from the Hugging Face Hub,
draws the predicted boxes on the uploaded image, and lists each detection
with its confidence score.
"""

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, AutoModelForObjectDetection

MODEL_ID = "0llheaven/Conditional-detr-finetuned-tf"

# Load the model and processor once at import time so every request reuses them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForObjectDetection.from_pretrained(MODEL_ID)


def detect_objects(image: Image.Image, score_threshold: float):
    """Run the detector on *image* and draw boxes for predictions above a threshold.

    Args:
        image: Uploaded PIL image (any mode; it is converted to RGB), or
            ``None`` when the user submits without uploading anything.
        score_threshold: Minimum confidence a prediction needs in order to be
            drawn and listed (the UI slider ranges from 0 to 1).

    Returns:
        A tuple ``(annotated_image, text)``. ``annotated_image`` is an RGB copy
        of the input with red boxes and captions drawn on it. ``text`` has one
        ``"<label>: <score>"`` line per kept prediction, or is
        ``"No detection"`` when nothing passes the threshold.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # convert() always returns a new image (a copy even when the mode is already
    # RGB), so drawing below never mutates the caller's image object.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: avoid building an autograd graph for every request.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # Forward the user's threshold: otherwise the processor's default (0.5)
    # silently discards every prediction below 0.5, making lower slider
    # values ineffective.
    results = processor.post_process_object_detection(
        outputs, threshold=score_threshold, target_sizes=target_sizes
    )

    draw = ImageDraw.Draw(image)
    labels_output = []
    for result in results:
        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            # NOTE(review): class-id mapping is hard-coded (id 0 -> "Pneumonia",
            # anything else -> "No detection"); confirm against
            # model.config.id2label for this checkpoint.
            label_name = "Pneumonia" if label.item() == 0 else "No detection"
            caption = f"{label_name}: {round(score.item(), 3)}"
            coords = [round(coord, 2) for coord in box.tolist()]
            draw.rectangle(coords, outline="red", width=3)
            draw.text((coords[0], coords[1]), caption, fill="red")
            labels_output.append(caption)

    # An empty join yields "", in which case report that nothing was found.
    return image, "\n".join(labels_output) or "No detection"


# Create the Gradio interface: an image plus a score-threshold slider in,
# the annotated image plus a text summary out.
interface = gr.Interface(
    fn=detect_objects,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1, value=0.5, label="Score Threshold"),
    ],
    outputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Detected Objects"),
    ],
    title="Object Detection with Transformers",
    description=(
        "Upload an image to detect objects using a fine-tuned "
        "Conditional-DETR model."
    ),
)

if __name__ == "__main__":
    interface.launch()