# Author: Kaushik Mellacheruvu
# Commit: 28e5dbf (verified) - "Update app.py"
import os
import streamlit as st
import io
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
# Local directory where Hugging Face downloads of the base SAM config/processor are cached.
CACHE_DIR = "./newcache/"
# Base SAM checkpoint providing the architecture config and the preprocessing pipeline.
BASE_MODEL_NAME = "facebook/sam-vit-base"
# Fine-tuned state dict loaded on top of the base architecture.
CHECKPOINT_PATH = "mito_model_checkpoint.pth"


@st.cache_resource
def _load_model():
    """Build the fine-tuned SAM model and its processor once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction; without ``st.cache_resource`` the config, processor and full
    model weights would be reloaded on every rerun.

    Returns:
        tuple: ``(config, processor, model, device)`` where ``model`` has had
        the fine-tuned weights loaded, has been moved to ``device`` and has
        been switched to evaluation mode.
    """
    config = SamConfig.from_pretrained(BASE_MODEL_NAME, cache_dir=CACHE_DIR)
    sam_processor = SamProcessor.from_pretrained(BASE_MODEL_NAME, cache_dir=CACHE_DIR)

    # Instantiate the architecture only; the weights come from the local checkpoint.
    model = SamModel(config=config)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, rather than arbitrary pickled objects.
    state_dict = torch.load(
        CHECKPOINT_PATH, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(run_device)
    # A model built through its constructor (not from_pretrained) starts in
    # training mode; inference must run in eval mode.
    model.eval()
    return config, sam_processor, model, run_device


model_config, processor, my_mito_model, device = _load_model()
def get_bounding_box(ground_truth_map, *, max_perturbation=20):
    """Return a jittered ``[x_min, y_min, x_max, y_max]`` box around a mask's foreground.

    Every pixel with a value ``> 0`` counts as foreground. First the tight
    bounding box of the foreground is computed. Then each edge is pushed
    outward by an independent random amount from
    ``np.random.randint(0, max_perturbation)``, i.e. 0 to
    ``max_perturbation - 1`` pixels. Finally the box is clamped to
    ``[0, W]`` x ``[0, H]``.

    Args:
        ground_truth_map: Mask array of shape ``(H, W)``, or ``(H, W, C)`` as
            produced by ``np.array`` on a multi-channel PIL image; a pixel is
            foreground if any of its channels is ``> 0``.
        max_perturbation: Exclusive upper bound of the random outward shift
            applied to each edge. ``0`` (or less) disables the jitter and makes
            the result deterministic.

    Returns:
        list[int]: ``[x_min, y_min, x_max, y_max]`` in pixel coordinates.

    Raises:
        ValueError: If the mask has fewer than two dimensions or contains no
            foreground pixel (no box can be derived).
    """
    mask = np.asarray(ground_truth_map)
    if mask.ndim < 2:
        raise ValueError(
            f"ground_truth_map must be at least 2-D, got shape {mask.shape}"
        )
    foreground = mask > 0
    if foreground.ndim > 2:
        # Collapse trailing channel axes: a pixel is foreground if any channel is.
        foreground = foreground.any(axis=tuple(range(2, foreground.ndim)))

    y_indices, x_indices = np.nonzero(foreground)
    if y_indices.size == 0:
        raise ValueError(
            "ground_truth_map contains no foreground pixels; cannot derive a bounding box"
        )

    x_min, x_max = int(x_indices.min()), int(x_indices.max())
    y_min, y_max = int(y_indices.min()), int(y_indices.max())

    def _jitter():
        # Random outward shift for one edge; draws happen in the order
        # x_min, x_max, y_min, y_max so seeded runs are reproducible.
        if max_perturbation <= 0:
            return 0
        return int(np.random.randint(0, max_perturbation))

    height, width = foreground.shape
    x_min = max(0, x_min - _jitter())
    x_max = min(width, x_max + _jitter())
    y_min = max(0, y_min - _jitter())
    y_max = min(height, y_max + _jitter())
    return [x_min, y_min, x_max, y_max]
# Fixed point prompts sent when ``prompt_flag`` is true: ``[x, y]`` pixel
# coordinates in the uploaded image's frame (the processor rescales them to the
# model's input size). NOTE(review): presumably hand-traced along one boundary
# of a specific test image; they are NOT derived from the uploaded mask.
DEFAULT_PROMPT_POINTS = [
    [18.062770562770567, 252.59090909090907],
    [25.681818181818187, 224.19264069264068],
    [42.305194805194816, 195.79437229437227],
    [58.928571428571445, 176.40043290043286],
    [72.7813852813853, 167.39610389610385],
    [91.482683982684, 156.31385281385278],
    [112.26190476190479, 152.1580086580086],
    [128.1926406926407, 154.9285714285714],
    [144.12337662337666, 155.6212121212121],
    [157.2835497835498, 158.39177489177487],
    [164.90259740259745, 161.85497835497833],
    [179.44805194805195, 169.47402597402595],
    [189.83766233766238, 174.3225108225108],
    [198.8419913419914, 180.55627705627703],
    [209.92424242424244, 190.94588744588742],
    [220.31385281385286, 196.48701298701297],
    [230.70346320346323, 202.7207792207792],
    [239.01515151515156, 211.0324675324675],
    [250.0974025974026, 219.3441558441558],
]


# Function to perform MedSAM segmentation on an image
def segment_with_medsam(image, mask_np, prompt_flag, points=None):
    """Segment ``image`` with the fine-tuned SAM model and return a binary mask image.

    A box prompt is always derived from ``mask_np`` via ``get_bounding_box``,
    which adds random jitter to the box. When ``prompt_flag`` is true, point
    prompts are sent as well.

    Args:
        image: Input image passed to ``SamProcessor`` (a PIL image in this app).
        mask_np: Ground-truth mask array; used only to derive the box prompt.
        prompt_flag: If true, also pass point prompts to the model.
        points: Optional list of ``[x, y]`` point prompts in the image's pixel
            coordinates; defaults to ``DEFAULT_PROMPT_POINTS``. Ignored when
            ``prompt_flag`` is false.

    Returns:
        PIL.Image.Image: 8-bit grayscale image with the same size as ``image``.
        Pixels are 255 where the predicted mask logit is > 0 (equivalently
        sigmoid probability > 0.5) and 0 elsewhere.
    """
    bbox = get_bounding_box(mask_np)
    prompt_kwargs = {"input_boxes": [[bbox]]}
    if prompt_flag:
        prompt_points = DEFAULT_PROMPT_POINTS if points is None else points
        # Nesting is [image batch][point batch][points][xy]; no input_labels are
        # passed, so the model's default labels are used.
        prompt_kwargs["input_points"] = [[prompt_points]]

    inputs = processor(image, return_tensors="pt", **prompt_kwargs)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = my_mito_model(**inputs, multimask_output=False)

    # pred_masks are low-resolution logits in the model's padded input frame.
    # post_process_masks upsamples them, strips the padding and resizes them
    # back to the original image size. binarize=True thresholds the logits at
    # 0, which is the same decision as sigmoid(logits) > 0.5.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
        binarize=True,
    )
    # One image, one point batch, one mask (multimask_output=False) -> (H, W).
    binary_mask = masks[0][0, 0].numpy().astype(np.uint8)
    # A 2-D uint8 array becomes a mode "L" (8-bit grayscale) PIL image.
    return Image.fromarray(binary_mask * 255)
def main():
    """Render the Streamlit page: upload an image and a mask, then show SAM results.

    The ground-truth mask is used only to derive the box prompt. The image is
    segmented twice, once without and once with the point prompts. Both
    resulting masks are displayed below the uploaded image.
    """
    # file_uploader accepts a list of extensions; ".tif" is as common as ".tiff".
    tiff_types = ["tif", "tiff"]
    uploaded_file_1 = st.file_uploader("Upload Test Image", type=tiff_types)
    uploaded_file_2 = st.file_uploader(
        "Upload Ground Truth to prompt a Bounding Box", type=tiff_types
    )
    if uploaded_file_1 is None or uploaded_file_2 is None:
        return

    tiff_image = Image.open(uploaded_file_1)
    tiff_mask = Image.open(uploaded_file_2)
    mask_np = np.array(tiff_mask)
    # The box prompt needs at least one foreground (> 0) pixel; report an empty
    # mask on the page instead of crashing the app.
    if not np.any(mask_np > 0):
        st.error(
            "The ground-truth mask has no foreground pixels, so no bounding box can be derived."
        )
        return

    with st.spinner("Segmenting..."):
        # The box prompt is randomly jittered on every call. Replaying the global
        # NumPy RNG state gives both runs the same box, so the two results
        # differ only by the presence of the point prompts.
        rng_state = np.random.get_state()
        segmentation_mask_no_prompt = segment_with_medsam(tiff_image, mask_np, False)
        np.random.set_state(rng_state)
        segmentation_mask_with_prompt = segment_with_medsam(tiff_image, mask_np, True)

    st.subheader("Segmentation Results")
    st.image(tiff_image, caption="Uploaded Image")
    st.image(segmentation_mask_no_prompt, caption="Segmented Image")
    st.image(segmentation_mask_with_prompt, caption="Segmented Image with occlusion fixed")


if __name__ == "__main__":
    main()