"""Shiny app that segments an uploaded image with a fine-tuned SAM model."""

import base64
import io

import numpy as np
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from transformers import SamConfig, SamModel, SamProcessor

# Hugging Face model whose configuration and processor match the fine-tuned weights.
BASE_MODEL_NAME = "facebook/sam-vit-base"
# Fine-tuned weights loaded into a freshly configured SamModel.
CHECKPOINT_PATH = "models/model_checkpoint_trained_on_train.pth"

# Load model configuration and the matching pre/post-processor.
model_config = SamConfig.from_pretrained(BASE_MODEL_NAME)
processor = SamProcessor.from_pretrained(BASE_MODEL_NAME)

# Create an instance of the model and load the fine-tuned weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files;
# on torch >= 1.13 consider passing weights_only=True.
my_model = SamModel(config=model_config)
my_model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
)
my_model.eval()


def predict(image, threshold=0.5):
    """Segment ``image`` with SAM, prompting with a box covering the whole frame.

    Args:
        image: Image to segment (anything ``np.array`` can convert, e.g. a PIL
            image).
        threshold: Sigmoid probability above which a pixel is marked as
            foreground in the hard mask. Defaults to 0.5.

    Returns:
        A ``(None, hard_mask, prob_image)`` tuple. The first slot is a
        placeholder for a ground-truth mask, which is not available here.
        ``hard_mask`` is a ``uint8`` array of 0/1 values. ``prob_image`` is an
        RGB PIL image of the probabilities scaled to 0-255.
    """
    image_array = np.array(image)
    height, width = image_array.shape[:2]
    # Default bounding box to the entire frame: [x_min, y_min, x_max, y_max].
    prompt = [0, 0, width, height]

    inputs = processor(
        images=image_array,
        input_boxes=[[prompt]],
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = my_model(**inputs, multimask_output=False)

    # Drop the singleton axes. This assumes one image and one box yield a
    # single 2-D low-resolution mask. TODO: confirm the size (SAM's is
    # typically 256x256).
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().numpy()

    # Convert the soft mask to a hard mask using the threshold.
    hard_mask = (seg_prob > threshold).astype(np.uint8)

    # Replicate the probability map into three channels for an RGB preview.
    seg_prob_scaled = np.stack((seg_prob * 255,) * 3, axis=-1).astype(np.uint8)
    prob_image = Image.fromarray(seg_prob_scaled)

    return None, hard_mask, prob_image


def display_images_as_data_url(img):
    """Encode a PIL image as a base64 PNG ``data:`` URL usable as ``<img src>``."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"


def _image_card(src, caption):
    """Build one result tile: a full-width image with a caption underneath."""
    return ui.div(
        ui.img(src=src, width="100%"),
        ui.h5(caption),
    )


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: segment the uploaded image and render the three previews."""

    @reactive.Calc
    def segmented_data():
        """Return ``(image, gt_mask, pred_mask, prob_image)`` for the current upload.

        All four entries are ``None`` while no file has been uploaded.
        """
        uploaded_file = input.upload_image()
        if not uploaded_file:
            return None, None, None, None

        uploaded_file_path = uploaded_file[0]["datapath"]
        # The context manager releases the file handle once the RGB copy
        # exists (convert() returns a new, fully loaded image).
        with Image.open(uploaded_file_path) as uploaded_image:
            image = uploaded_image.convert("RGB")
        gt_mask, pred_mask, prob_image = predict(image)
        return image, gt_mask, pred_mask, prob_image

    @output
    @render.ui
    def segmented_result():
        """Show the original image, predicted mask and probability map side by side."""
        image, _gt_mask, pred_mask, prob_image = segmented_data()
        # Test the image slot directly. Comparing the whole tuple to a tuple
        # of Nones would invoke __eq__ on images/arrays.
        if image is None:
            return "Upload an image to start segmentation."

        original_image_data = display_images_as_data_url(image.resize((256, 256)))
        # The mask holds 0/1 values; scale to 0/255 so it is visible as PNG.
        predicted_mask_data = display_images_as_data_url(
            Image.fromarray(pred_mask * 255)
        )
        prob_map_data = display_images_as_data_url(prob_image)

        return ui.div(
            _image_card(original_image_data, "Original Image"),
            _image_card(predicted_mask_data, "Predicted Mask"),
            _image_card(prob_map_data, "Probability Map"),
            style="display: flex;",
        )


# Colors in HEX format:
#   Purple: #800080
#   White:  #FFFFFF
#   Orange: #FFA500
app_ui = ui.page_fillable(
    ui.layout_sidebar(
        ui.sidebar(
            ui.panel_title("Image Segmentation"),
            ui.input_file("upload_image", "Upload an Image"),
            style="""
            background: linear-gradient(135deg, #FFA500 0%, #800080 100%);
            color: white;
            padding: 20px;
            border-radius: 0;
            min-height: 100vh;
            box-shadow: none;
            """,
        ),
        ui.output_ui("segmented_result"),
        style="""
        background: #f8f9fa;
        padding: 20px;
        height: 100vh;
        """,
    ),
    style="""
    background: linear-gradient(135deg, #800080 0%, #FFFFFF 100%);
    height: 100vh;
    """,
)

app = App(app_ui, server)

if __name__ == "__main__":
    app.run()