import streamlit as st
from transformers import SamModel, SamProcessor
from transformers import pipeline
from PIL import Image
import numpy as np
import torch
def main():
    st.title("Image Segmentation with Object Detection")

    # Introduction and how-to
    st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:

- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.

Please note that processing takes some time. We appreciate your patience as the models do their work!
""")

    # Model credits
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")
    # Load SlimSAM, a pruned, lighter-weight variant of SAM (facebook/sam-vit-huge)
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")

    # Load the object-detection pipeline
    od_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
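
    # NOTE: a sketch, not part of the original app. The models above reload on
    # every Streamlit rerun; a cached loader along these lines could avoid that.
    # The helper name `load_models` is our own; `st.cache_resource` is standard
    # Streamlit API:
    #
    #     @st.cache_resource
    #     def load_models():
    #         model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    #         processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    #         od_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
    #         return model, processor, od_pipe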
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Divisors used to place point prompts: (2.0, 2.0) puts a point at 50% of
    # the image width/height, (2.5, 2.5) at 40%.
    xs_ys = [(2.0, 2.0), (2.5, 2.5)]  # other candidates: (2.5, 2.0), (2.0, 2.5), (1.5, 1.5)
    width = 600

    if uploaded_file is not None:
        # Convert to RGB so the compositing below also works for RGBA/grayscale uploads
        raw_image = Image.open(uploaded_file).convert("RGB")
        st.subheader("Uploaded Image")
        st.image(raw_image, caption="Uploaded Image", width=width)
        ### STEP 1. Object detection
        pipeline_output = od_pipe(raw_image)
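        # Each detection in pipeline_output is a dict along these lines
        # (illustrative values, not real output):
        #   {'score': 0.998, 'label': 'cat',
        #    'box': {'xmin': 40, 'ymin': 70, 'xmax': 175, 'ymax': 118}}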
        # Convert the bounding boxes from the pipeline output into the
        # (x_min, y_min, x_max, y_max) format the SAM processor expects
        input_boxes_format = [
            [b['box']['xmin'], b['box']['ymin'], b['box']['xmax'], b['box']['ymax']]
            for b in pipeline_output
        ]
        labels_format = [b['label'] for b in pipeline_output]
        print(input_boxes_format)  # debug output in the server log
        print(labels_format)
        ### STEP 2. Prompt SlimSAM with each detected box
        for box, label in zip(input_boxes_format, labels_format):
            with st.spinner('Processing...'):
                st.subheader(f"Bounding box: {label}")
                # input_boxes is nested as (images, boxes per image, 4 coordinates)
                inputs = processor(images=raw_image,
                                   input_boxes=[[box]],
                                   return_tensors="pt")
                with torch.no_grad():
                    outputs = model(**inputs)
                # Upscale the low-resolution masks back to the original image size
                predicted_masks = processor.image_processor.post_process_masks(
                    outputs.pred_masks,
                    inputs["original_sizes"],
                    inputs["reshaped_input_sizes"]
                )
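                # post_process_masks returns one tensor per input image; with
                # SAM's default multimask output it holds three candidate masks
                # per prompt, i.e. roughly (num_prompts, 3, height, width)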
                predicted_mask = predicted_masks[0]
                # Show all three candidate masks SAM proposes for this box
                for i in range(3):
                    # 2D boolean mask -> 8-bit grayscale image
                    mask = predicted_mask[0][i]
                    int_mask = np.array(mask).astype(np.uint8) * 255
                    mask_image = Image.fromarray(int_mask, mode='L')
                    # Keep the object, paint the background white
                    final_image = Image.composite(
                        raw_image,
                        Image.new('RGB', raw_image.size, (255, 255, 255)),
                        mask_image)
                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)
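
        # (Sketch, not part of the original flow.) The same mask could tint the
        # object instead of whitening the background, e.g.:
        #     overlay = Image.new('RGB', raw_image.size, (255, 0, 0))
        #     tinted = Image.composite(overlay, raw_image, mask_image)
        #     st.image(Image.blend(raw_image, tinted, 0.5), width=width)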
        ### STEP 3. Prompt SlimSAM with single points
        for (x, y) in xs_ys:
            with st.spinner('Processing...'):
                # Place one point prompt at (width/x, height/y)
                point_x = int(raw_image.size[0] // x)
                point_y = int(raw_image.size[1] // y)
                input_points = [[[point_x, point_y]]]
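                # The triple nesting follows the processor's expected layout:
                # the outer list indexes images in the batch, the next the
                # points for that image, the innermost a point's (x, y) pixels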
                # Prepare inputs
                inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
                # Generate masks
                with torch.no_grad():
                    outputs = model(**inputs)
                # Post-process masks back to the original image size
                predicted_masks = processor.image_processor.post_process_masks(
                    outputs.pred_masks,
                    inputs["original_sizes"],
                    inputs["reshaped_input_sizes"]
                )
                predicted_mask = predicted_masks[0]
                # Display the three candidate masks
                st.subheader(f"Input points: ({1/x:.1f}, {1/y:.1f})")
                for i in range(3):
                    mask = predicted_mask[0][i]
                    int_mask = np.array(mask).astype(np.uint8) * 255
                    mask_image = Image.fromarray(int_mask, mode='L')
                    # Keep the object, paint the background white
                    final_image = Image.composite(
                        raw_image,
                        Image.new('RGB', raw_image.size, (255, 255, 255)),
                        mask_image)
                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)

if __name__ == "__main__":
    main()
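
# To try the app locally (assuming streamlit, torch, transformers, and
# pillow are installed):
#
#     streamlit run app.py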