|
import streamlit as st |
|
|
|
from transformers import SamModel, SamProcessor |
|
from transformers import pipeline |
|
from PIL import Image, ImageOps |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import requests |
|
from io import BytesIO |
|
|
|
def _load_models():
    """Load the SlimSAM model/processor and the DETR detection pipeline.

    Wrapped in ``st.cache_resource`` below so the heavy checkpoints are
    downloaded and instantiated once per process instead of on every
    Streamlit rerun (each widget interaction re-executes the script).
    """
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
    return model, processor, od_pipe


# NOTE(review): st.cache_resource requires Streamlit >= 1.18; guard with
# hasattr so older installs fall back to the uncached loader instead of
# crashing at import time.
if hasattr(st, "cache_resource"):
    _load_models = st.cache_resource(show_spinner=False)(_load_models)


def _segment(model, processor, raw_image, **prompts):
    """Run SlimSAM on ``raw_image`` with the given prompt kwargs.

    ``prompts`` is forwarded to the processor, e.g. ``input_boxes=[...]``
    or ``input_points=[...]``. Returns the post-processed masks for the
    single input image (indexable as ``result[0][i]`` for proposal ``i``).
    """
    inputs = processor(images=raw_image, return_tensors="pt", **prompts)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(**inputs)
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"],
    )
    return predicted_masks[0]


def _show_masks(raw_image, predicted_mask, width):
    """Display up to three mask proposals, each composited over white."""
    for i in range(3):
        mask = predicted_mask[0][i]
        # Boolean mask -> 8-bit grayscale (0 or 255), usable directly as a
        # compositing mask.
        int_mask = np.array(mask).astype('uint8') * 255
        mask_image = Image.fromarray(int_mask, mode='L')
        final_image = Image.composite(
            raw_image,
            Image.new('RGB', raw_image.size, (255, 255, 255)),
            mask_image,
        )
        st.image(final_image, caption=f"Masked Image {i+1}", width=width)


def main():
    """Streamlit entry point: DETR object detection followed by SlimSAM
    segmentation, prompted first by detected boxes and then by two fixed
    relative points."""
    st.title("Image Segmentation with Object Detection")

    st.markdown("""
    Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
    - **Upload an image**: Drag and drop or use the browse files option.
    - **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
    - **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
    - **Further Segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.

    Please note that processing takes some time. We appreciate your patience as the models do their work!
    """)

    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")

    model, processor, od_pipe = _load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Divisors for the extra point prompts: a point at (w // x, h // y),
    # i.e. relative positions (1/x, 1/y) -> (0.5, 0.5) and (0.4, 0.4).
    xs_ys = [(2.0, 2.0), (2.5, 2.5)]
    width = 600  # display width for every st.image call

    if uploaded_file is not None:
        # Force RGB: uploads may be RGBA or palette PNGs, and
        # Image.composite below requires both images to share a mode.
        raw_image = Image.open(uploaded_file).convert('RGB')

        st.subheader("Uploaded Image")
        st.image(raw_image, caption="Uploaded Image", width=width)

        pipeline_output = od_pipe(raw_image)

        input_boxes_format = [[[b['box']['xmin'], b['box']['ymin']],
                               [b['box']['xmax'], b['box']['ymax']]]
                              for b in pipeline_output]
        labels_format = [b['label'] for b in pipeline_output]
        print(input_boxes_format)
        print(labels_format)

        # One segmentation pass per detected bounding box.
        for box, label in zip(input_boxes_format, labels_format):
            with st.spinner('Processing...'):
                st.subheader(f'bounding box : {label}')
                predicted_mask = _segment(model, processor, raw_image,
                                          input_boxes=[box])
                _show_masks(raw_image, predicted_mask, width)

        # Extra passes prompted by single points instead of boxes.
        for (x, y) in xs_ys:
            with st.spinner('Processing...'):
                # int(): `size // float` yields a float, but pixel
                # coordinates should be integers.
                point_x = int(raw_image.size[0] // x)
                point_y = int(raw_image.size[1] // y)
                input_points = [[[point_x, point_y]]]

                predicted_mask = _segment(model, processor, raw_image,
                                          input_points=input_points)

                st.subheader(f"Input points : ({1/x},{1/y})")
                _show_masks(raw_image, predicted_mask, width)
|
|
|
# Run the app only when executed directly (e.g. `streamlit run <file>`),
# not when imported as a module.
if __name__ == "__main__":

    main()