"""Streamlit demo: segment a TIFF image with a fine-tuned SAM checkpoint.

The user uploads a test image and a ground-truth mask. A bounding box is
derived from the mask and used as the box prompt; the segmentation is run
once with the box alone and once with the box plus a fixed set of point
prompts, and both results are displayed.
"""

import io
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor

CACHE_DIR = "./newcache/"

# Fixed (x, y) point prompts sent together with the box when ``prompt_flag``
# is set. The values are hard-coded pixel coordinates.
# NOTE(review): these look hand-picked for one specific image -- confirm
# before using the app on other inputs.
_PROMPT_POINTS = (
    (18.062770562770567, 252.59090909090907),
    (25.681818181818187, 224.19264069264068),
    (42.305194805194816, 195.79437229437227),
    (58.928571428571445, 176.40043290043286),
    (72.7813852813853, 167.39610389610385),
    (91.482683982684, 156.31385281385278),
    (112.26190476190479, 152.1580086580086),
    (128.1926406926407, 154.9285714285714),
    (144.12337662337666, 155.6212121212121),
    (157.2835497835498, 158.39177489177487),
    (164.90259740259745, 161.85497835497833),
    (179.44805194805195, 169.47402597402595),
    (189.83766233766238, 174.3225108225108),
    (198.8419913419914, 180.55627705627703),
    (209.92424242424244, 190.94588744588742),
    (220.31385281385286, 196.48701298701297),
    (230.70346320346323, 202.7207792207792),
    (239.01515151515156, 211.0324675324675),
    (250.0974025974026, 219.3441558441558),
)


@st.cache_resource
def _load_model():
    """Build the SAM model once and reuse it across Streamlit reruns.

    Streamlit re-executes the script on every interaction; without caching
    the architecture would be rebuilt and the checkpoint re-read each time.

    Returns:
        Tuple ``(config, processor, model, device)``.
    """
    # Load model configuration and processor (replace with your model names)
    config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=CACHE_DIR)
    proc = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=CACHE_DIR)

    # Create the model architecture, then load the fine-tuned weights.
    model = SamModel(config=config)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(
        torch.load("mito_model_checkpoint.pth", map_location=torch.device("cpu"))
    )

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(dev)
    # A model built from a config is not switched to eval mode automatically.
    model.eval()
    return config, proc, model, dev


model_config, processor, my_mito_model, device = _load_model()


def get_bounding_box(ground_truth_map, max_perturbation=20):
    """Return a randomly enlarged bounding box around the mask foreground.

    Args:
        ground_truth_map: 2-D array; pixels with value > 0 are foreground.
        max_perturbation: Exclusive upper bound of the random number of
            pixels each side of the box is pushed outwards. ``0`` (or less)
            disables the perturbation.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain Python ints, clamped to
        ``[0, W]`` horizontally and ``[0, H]`` vertically.

    Raises:
        ValueError: If the mask is not 2-D or has no pixel greater than 0.
    """
    mask = np.asarray(ground_truth_map)
    if mask.ndim != 2:
        raise ValueError(
            f"expected a 2-D mask, got an array with shape {mask.shape}"
        )

    y_indices, x_indices = np.nonzero(mask > 0)
    if y_indices.size == 0:
        # np.min/np.max on an empty array would fail with an opaque message.
        raise ValueError("mask contains no foreground (no pixel > 0)")

    def _jitter():
        # np.random.randint(0, 0) is an error, so short-circuit that case.
        if max_perturbation <= 0:
            return 0
        return int(np.random.randint(0, max_perturbation))

    height, width = mask.shape
    # Draw order (x_min, x_max, y_min, y_max) is kept so that seeded runs
    # produce the same boxes as before.
    x_min = max(0, int(x_indices.min()) - _jitter())
    x_max = min(width, int(x_indices.max()) + _jitter())
    y_min = max(0, int(y_indices.min()) - _jitter())
    y_max = min(height, int(y_indices.max()) + _jitter())

    return [x_min, y_min, x_max, y_max]


def segment_with_medsam(image, mask_np, prompt_flag, bbox=None):
    """Segment ``image`` with the loaded SAM model and return a binary mask.

    Args:
        image: Image passed to the SAM processor (a PIL image in this app).
        mask_np: Ground-truth mask array; used to derive the box prompt when
            ``bbox`` is not supplied.
        prompt_flag: When truthy, the fixed point prompts in
            ``_PROMPT_POINTS`` are sent in addition to the box.
        bbox: Optional ``[x_min, y_min, x_max, y_max]`` box prompt. Passing
            it lets several calls share the same (randomly perturbed) box.

    Returns:
        A mode-"L" PIL image with 255 where the sigmoid probability exceeds
        0.5 and 0 elsewhere.

    Raises:
        ValueError: Propagated from ``get_bounding_box`` when ``bbox`` is
            omitted and the mask is unusable.
    """
    if bbox is None:
        bbox = get_bounding_box(mask_np)

    prompts = {"input_boxes": [[bbox]]}
    if prompt_flag:
        # No labels are passed with the points.
        # NOTE(review): presumably the model then treats them all as
        # foreground -- confirm against the transformers SAM documentation.
        prompts["input_points"] = [[[list(pt) for pt in _PROMPT_POINTS]]]

    inputs = processor(image, return_tensors="pt", **prompts)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Perform inference
    with torch.no_grad():
        outputs = my_mito_model(**inputs, multimask_output=False)

    # Apply sigmoid and threshold at 0.5 to obtain a hard mask.
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)

    seg_uint8 = (seg * 255).astype(np.uint8)
    return Image.fromarray(seg_uint8).convert("L")


def main():
    """Define the Streamlit app layout and logic."""
    uploaded_file_1 = st.file_uploader("Upload Test Image", type=["tiff", "tif"])
    uploaded_file_2 = st.file_uploader(
        "Upload Ground Truth to prompt a Bounding Box", type=["tiff", "tif"]
    )

    if uploaded_file_1 is None or uploaded_file_2 is None:
        return

    tiff_image = Image.open(uploaded_file_1)
    tiff_mask = Image.open(uploaded_file_2)
    mask_np = np.array(tiff_mask)

    # Compute the box once so both runs below use the same box prompt and
    # differ only in whether the point prompts are supplied.
    try:
        bbox = get_bounding_box(mask_np)
    except ValueError as err:
        st.error(f"Cannot derive a bounding box from the ground truth: {err}")
        return

    # Perform segmentation
    segmentation_mask_no_prompt = segment_with_medsam(
        tiff_image, mask_np, False, bbox=bbox
    )
    segmentation_mask_with_prompt = segment_with_medsam(
        tiff_image, mask_np, True, bbox=bbox
    )

    st.subheader("Segmentation Results")
    st.image(tiff_image, caption="Uploaded Image")
    st.image(segmentation_mask_no_prompt, caption="Segmented Image")
    st.image(
        segmentation_mask_with_prompt,
        caption="Segmented Image with occlusion fixed",
    )


if __name__ == "__main__":
    main()