# app.py (Hugging Face Space by beingcognitive; commit 31361ed "Update app.py", 5.8 kB)
import streamlit as st
# from transformers import AutoProcessor, AutoModelForMaskGeneration
from transformers import SamModel, SamProcessor
from transformers import pipeline
from PIL import Image, ImageOps
# from PIL import Image
import numpy as np
# import matplotlib.pyplot as plt
import torch
import requests
from io import BytesIO
@st.cache_resource
def _load_models():
    """Load the SAM model/processor and the DETR detection pipeline once.

    Streamlit re-runs the script on every interaction; caching these
    heavyweight objects avoids re-loading the weights on each rerun.

    Returns:
        tuple: ``(model, processor, od_pipe)``.
    """
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
    return model, processor, od_pipe


def _predict_masks(model, processor, image, **prompts):
    """Run SAM on a single image with one prompt and return its masks.

    Args:
        model: The ``SamModel`` used for inference.
        processor: The matching ``SamProcessor``.
        image: RGB ``PIL.Image`` to segment.
        **prompts: Forwarded to the processor, e.g. ``input_boxes=...`` or
            ``input_points=...``.

    Returns:
        The post-processed masks for the first image and first prompt,
        one entry per candidate mask (SAM's default multimask output
        yields three).
    """
    inputs = processor(images=image, return_tensors="pt", **prompts)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"],
    )
    # [first image][first prompt] -> stack of candidate masks.
    return predicted_masks[0][0]


def _show_masked_images(image, masks, width):
    """Display *image* once per mask, painting pixels outside the mask white.

    Args:
        image: RGB ``PIL.Image`` that the masks were computed for.
        masks: Iterable of 2-D boolean masks with the same size as *image*.
        width: Display width in pixels passed to ``st.image``.
    """
    white = Image.new("RGB", image.size, (255, 255, 255))
    for i, mask in enumerate(masks):
        # Boolean mask -> 8-bit grayscale image (0 = outside, 255 = inside);
        # a 2-D uint8 array becomes a mode "L" image automatically.
        mask_image = Image.fromarray(np.asarray(mask).astype(np.uint8) * 255)
        final_image = Image.composite(image, white, mask_image)
        st.image(final_image, caption=f"Masked Image {i+1}", width=width)


def main():
    """Render the app: detect objects in an uploaded image, then segment them.

    After an image is uploaded, DETR detects objects. SAM is then prompted
    once per detected bounding box and once per fixed relative point. Every
    candidate mask is shown as a masked copy of the image.
    """
    st.title("Image Segmentation with Object Detection")

    # Introduction and How-to
    st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further Segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.
Please note that processing takes some time. We appreciate your patience as the models do their work!
""")

    # Model credits
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")

    model, processor, od_pipe = _load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    # Divisors of image width/height; (2.0, 2.0) prompts at the relative
    # position (0.5, 0.5), (2.5, 2.5) at (0.4, 0.4).
    xs_ys = [(2.0, 2.0), (2.5, 2.5)]
    width = 600

    if uploaded_file is None:
        return

    try:
        # Force RGB: uploads may be RGBA/palette/grayscale, but
        # Image.composite needs the image and the white background in the
        # same mode.
        raw_image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # includes PIL.UnidentifiedImageError and truncated files
        st.error("Could not read the uploaded file as an image.")
        return

    st.subheader("Uploaded Image")
    st.image(raw_image, caption="Uploaded Image", width=width)

    ### STEP 1. Object Detection, then box-prompted segmentation
    pipeline_output = od_pipe(raw_image)
    if not pipeline_output:
        st.info("No objects were detected in this image.")
    for detection in pipeline_output:
        box = detection["box"]
        with st.spinner('Processing...'):
            st.subheader(f'bounding box : {detection["label"]}')
            # SAM expects [image batch][boxes per image][x_min, y_min, x_max, y_max].
            masks = _predict_masks(
                model,
                processor,
                raw_image,
                input_boxes=[[[box["xmin"], box["ymin"], box["xmax"], box["ymax"]]]],
            )
            _show_masked_images(raw_image, masks, width)

    ### STEP 2. Point-prompted segmentation
    for x, y in xs_ys:
        with st.spinner('Processing...'):
            point_x = raw_image.size[0] // x
            point_y = raw_image.size[1] // y
            masks = _predict_masks(
                model,
                processor,
                raw_image,
                input_points=[[[point_x, point_y]]],
            )
            st.subheader(f"Input points : ({1/x},{1/y})")
            _show_masked_images(raw_image, masks, width)
# Launch the app when this module is executed as a script.
if __name__ == "__main__":
    main()