"""Streamlit app that segments mitochondria in uploaded TIFF images with a fine-tuned
SAM (Segment Anything Model) checkpoint. The model is prompted with a bounding box
derived from an uploaded ground-truth mask and, optionally, with a fixed set of point
prompts used to recover an occluded region."""

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

CACHE_DIR = "./newcache/"

# Load the base SAM configuration and processor from the Hugging Face Hub.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=CACHE_DIR)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=CACHE_DIR)

# Build a SAM model from the configuration; its weights come entirely from the
# fine-tuned mitochondria checkpoint loaded below.
my_mito_model = SamModel(config=model_config)
my_mito_model.load_state_dict(torch.load("mito_model_checkpoint.pth", map_location=torch.device("cpu")))

device = "cuda" if torch.cuda.is_available() else "cpu"
my_mito_model.to(device)
my_mito_model.eval()  # inference only

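# "mito_model_checkpoint.pth" is expected to hold a plain state_dict, i.e. a checkpoint
# saved during fine-tuning with something along the lines of (sketch, not shown here):
#   torch.save(my_mito_model.state_dict(), "mito_model_checkpoint.pth")
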
def get_bounding_box(ground_truth_map):
    """Return a bounding box [x_min, y_min, x_max, y_max] around the non-zero region
    of a ground-truth mask, loosened by a small random margin, for use as a SAM box
    prompt."""
    y_indices, x_indices = np.where(ground_truth_map > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Perturb each coordinate by up to 20 pixels, clamped to the image bounds.
    H, W = ground_truth_map.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    bbox = [x_min, y_min, x_max, y_max]
    return bbox

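# Hypothetical toy example: for a 10x10 mask with foreground in rows/columns 3..6,
# get_bounding_box returns a box around [3, 3, 6, 6], loosened by the random margins:
#   toy_mask = np.zeros((10, 10)); toy_mask[3:7, 3:7] = 1
#   get_bounding_box(toy_mask)  # e.g. [1, 0, 10, 9], depending on the random offsets
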
def segment_with_medsam(image, mask_np, prompt_flag):
    """Segment `image` with the fine-tuned SAM model, prompted by a bounding box derived
    from `mask_np` and, when `prompt_flag` is True, by an additional set of point prompts."""
    bbox = get_bounding_box(mask_np)

    # Hand-picked (x, y) point prompts used to recover the occluded part of the object.
    points = [[18.062770562770567, 252.59090909090907],
              [25.681818181818187, 224.19264069264068],
              [42.305194805194816, 195.79437229437227],
              [58.928571428571445, 176.40043290043286],
              [72.7813852813853, 167.39610389610385],
              [91.482683982684, 156.31385281385278],
              [112.26190476190479, 152.1580086580086],
              [128.1926406926407, 154.9285714285714],
              [144.12337662337666, 155.6212121212121],
              [157.2835497835498, 158.39177489177487],
              [164.90259740259745, 161.85497835497833],
              [179.44805194805195, 169.47402597402595],
              [189.83766233766238, 174.3225108225108],
              [198.8419913419914, 180.55627705627703],
              [209.92424242424244, 190.94588744588742],
              [220.31385281385286, 196.48701298701297],
              [230.70346320346323, 202.7207792207792],
              [239.01515151515156, 211.0324675324675],
              [250.0974025974026, 219.3441558441558]]

    # Preprocess the image together with the prompts.
    if prompt_flag:
        inputs = processor(image, input_boxes=[[bbox]], input_points=[[points]], return_tensors="pt")
    else:
        inputs = processor(image, input_boxes=[[bbox]], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = my_mito_model(**inputs, multimask_output=False)

    # Convert the predicted mask logits to a binary mask.
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

    # Return the mask as an 8-bit grayscale PIL image for display.
    binary_image_array_uint8 = (medsam_seg * 255).astype(np.uint8)
    image = Image.fromarray(binary_image_array_uint8)
    image = image.convert("L")
    return image

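# Note (optional refinement, not part of the original flow): outputs.pred_masks comes
# back at the model's low-resolution mask size, so to display a mask matching the
# uploaded image's resolution one could rescale it with the processor helper, e.g.:
#   masks = processor.image_processor.post_process_masks(
#       outputs.pred_masks.cpu(),
#       inputs["original_sizes"].cpu(),
#       inputs["reshaped_input_sizes"].cpu(),
#   )
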
def main():
    """Define the Streamlit app layout and logic."""
    uploaded_file_1 = st.file_uploader("Upload Test Image", type="tiff")
    uploaded_file_2 = st.file_uploader("Upload Ground Truth to prompt a Bounding Box", type="tiff")

    if uploaded_file_1 is not None and uploaded_file_2 is not None:
        tiff_image = Image.open(uploaded_file_1)
        tiff_mask = Image.open(uploaded_file_2)

        # The ground-truth mask is only used to derive the bounding-box prompt.
        mask_np = np.array(tiff_mask)

        # Segment once with only the bounding-box prompt, and once with the
        # additional point prompts.
        segmentation_mask_no_prompt = segment_with_medsam(tiff_image, mask_np, False)
        segmentation_mask_with_prompt = segment_with_medsam(tiff_image, mask_np, True)

        st.subheader("Segmentation Results")
        st.image(tiff_image, caption="Uploaded Image")
        st.image(segmentation_mask_no_prompt, caption="Segmented Image")
        st.image(segmentation_mask_with_prompt, caption="Segmented Image with occlusion fixed")


if __name__ == "__main__":
    main()
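# Launch locally with Streamlit (assuming this file is saved as, e.g., app.py):
#   streamlit run app.py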