"""Streamlit app: detect objects with DETR, then segment them with SlimSAM."""

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageOps
from transformers import SamModel, SamProcessor, pipeline

# Constants
# Each (x, y) pair places a SAM prompt point at (width // x, height // y),
# i.e. at the relative position (1/x, 1/y) of the image: (0.5, 0.5) and (0.4, 0.4).
XS_YS = [(2.0, 2.0), (2.5, 2.5)]
# Display width, in pixels, for every image rendered with st.image.
WIDTH = 600

SAM_CHECKPOINT = "Zigeng/SlimSAM-uniform-77"
DETECTION_CHECKPOINT = "facebook/detr-resnet-50"

# Kept at column 0 with blank lines between sections so the bullet list
# renders as a list instead of one run-on paragraph.
INTRO_MARKDOWN = f"""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects.

Here's how it works:

- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `{DETECTION_CHECKPOINT}` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `{SAM_CHECKPOINT}` segments the objects using the bounding box data.
- **Further Segmentation**: The app also provides additional segmentation insights using input points at relative positions (0.4, 0.4) and (0.5, 0.5) (fractions of the image width and height) for a more granular analysis.

Please note that processing takes some time. We appreciate your patience as the models do their work!
"""


# Load models
@st.cache_resource
def load_models():
    """Load the SAM model/processor and the DETR object-detection pipeline.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of being reloaded on every rerun.

    Returns:
        A ``(model, processor, od_pipe)`` tuple.
    """
    model = SamModel.from_pretrained(SAM_CHECKPOINT)
    processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
    od_pipe = pipeline("object-detection", DETECTION_CHECKPOINT)
    return model, processor, od_pipe


def process_image(image, model, processor, bounding_box=None, input_point=None):
    """Run SAM on *image*, prompted by either a bounding box or a single point.

    Args:
        image: PIL image; converted to RGB before being handed to the processor.
        model: ``SamModel`` returned by :func:`load_models`.
        processor: ``SamProcessor`` returned by :func:`load_models`.
        bounding_box: Box prompt as ``[[xmin, ymin], [xmax, ymax]]`` in pixel
            coordinates. Takes precedence over *input_point* when both are given.
        input_point: Point prompt as ``[x, y]`` in pixel coordinates.

    Returns:
        ``predicted_masks[0]`` from ``post_process_masks`` (the masks for this
        single image), or ``None`` if anything failed. Failures, including a
        call with neither prompt, are reported with ``st.error`` rather than
        raised.
    """
    try:
        image_array = np.array(image.convert("RGB"))

        # Compare against None explicitly: a prompt is "given" whenever it is
        # passed, regardless of its truthiness.
        if bounding_box is not None:
            inputs = processor(images=image_array, input_boxes=[bounding_box], return_tensors="pt")
        elif input_point is not None:
            inputs = processor(images=image_array, input_points=[[input_point]], return_tensors="pt")
        else:
            raise ValueError("Either bounding_box or input_point must be provided")

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Upscale the low-resolution mask logits back to the original image size.
        predicted_masks = processor.image_processor.post_process_masks(
            outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
        )
        return predicted_masks[0]
    except Exception as e:  # UI boundary: show the error instead of crashing the page
        st.error(f"Error processing image: {str(e)}")
        return None


def display_masked_images(raw_image, predicted_mask, caption_prefix):
    """Show every mask candidate for the first prompt, cut out over white.

    Args:
        raw_image: PIL image the masks were predicted for; may be in any mode.
        predicted_mask: Value returned by :func:`process_image`. Its first entry
            is iterated as a sequence of 2-D masks (SAM typically yields 3
            candidates per prompt — NOTE: confirm for other configurations).
        caption_prefix: Text placed before the 1-based mask index in captions.
    """
    # Image.composite requires both images to share a mode, and uploads may be
    # RGBA / P / L, so normalise to RGB once outside the loop.
    rgb_image = raw_image.convert("RGB")
    background = Image.new("RGB", rgb_image.size, (255, 255, 255))

    # Iterate over however many candidates were returned instead of assuming 3.
    for index, mask in enumerate(predicted_mask[0], start=1):
        # Boolean mask -> 0/255 uint8; a 2-D uint8 array becomes an "L" image.
        mask_array = np.asarray(mask).astype(np.uint8) * 255
        mask_image = Image.fromarray(mask_array)
        final_image = Image.composite(rgb_image, background, mask_image)
        st.image(final_image, caption=f"{caption_prefix} {index}", width=WIDTH)


def _open_image(uploaded_file):
    """Decode an uploaded file into an upright RGB image, or return None on failure."""
    try:
        image = Image.open(uploaded_file)
        # Honour the EXIF orientation tag so phone photos are neither shown
        # nor segmented sideways.
        image = ImageOps.exif_transpose(image)
        # Normalise once: DETR, SAM and Image.composite all want 3-channel RGB.
        return image.convert("RGB")
    except OSError as e:  # includes PIL.UnidentifiedImageError and truncated files
        st.error(f"Could not read the uploaded image: {e}")
        return None


def _segment_detected_objects(image, model, processor, od_pipe):
    """Detect objects with DETR and show SAM masks prompted by each detected box."""
    detections = od_pipe(image)
    if not detections:
        st.info("No objects were detected in this image.")
        return

    for detection in detections:
        box = detection["box"]
        # Corner-pair layout [[xmin, ymin], [xmax, ymax]], as expected by process_image.
        bounding_box = [[box["xmin"], box["ymin"]], [box["xmax"], box["ymax"]]]
        st.subheader(f"bounding box : {detection['label']}")
        predicted_mask = process_image(image, model, processor, bounding_box=bounding_box)
        if predicted_mask is not None:
            display_masked_images(image, predicted_mask, "Masked Image")


def _segment_fixed_points(image, model, processor):
    """Show SAM masks prompted by the fixed relative points listed in XS_YS."""
    width, height = image.size
    for x, y in XS_YS:
        point = [width // x, height // y]
        st.subheader(f"Input points : ({1/x},{1/y})")
        predicted_mask = process_image(image, model, processor, input_point=point)
        if predicted_mask is not None:
            display_masked_images(image, predicted_mask, "Masked Image")


def main():
    """Streamlit entry point: intro text, image upload, then detection + segmentation."""
    st.title("Image Segmentation with Object Detection")

    # Introduction and How-to
    st.markdown(INTRO_MARKDOWN)

    # Model credits
    st.subheader("Powered by:")
    st.write(f"- Object Detection Model: `{DETECTION_CHECKPOINT}`")
    st.write(f"- Segmentation Model: `{SAM_CHECKPOINT}`")

    model, processor, od_pipe = load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    raw_image = _open_image(uploaded_file)
    if raw_image is None:
        return

    st.subheader("Uploaded Image")
    st.image(raw_image, caption="Uploaded Image", width=WIDTH)

    with st.spinner("Processing image..."):
        _segment_detected_objects(raw_image, model, processor, od_pipe)
        _segment_fixed_points(raw_image, model, processor)


if __name__ == "__main__":
    main()