# Shiny Express app: upload an image and display a SAM-predicted sidewalk mask.
#
# A fine-tuned SamModel (architecture/processor from "facebook/sam-vit-base",
# weights from MODEL_WEIGHTS_PATH) is prompted with a regular grid of points
# over the uploaded image and its thresholded mask is rendered as an image.
# NOTE: top-level comments are used instead of a module docstring because
# Shiny Express renders bare top-level expressions (including strings).

import tempfile
from functools import lru_cache

# NOTE(review): several of the imports below (ipyleaflet, faicons, geopy,
# reactive, render_widget, ipywidgets, io, base64, matplotlib) look unused in
# this file; kept in case other code relies on them.
import ipyleaflet as L
from transformers import SamModel, SamConfig, SamProcessor
import torch
from faicons import icon_svg
from geopy.distance import geodesic, great_circle
from shiny import reactive
from shiny.express import input, render, ui
from shinywidgets import render_widget
import numpy as np
import ipywidgets as widgets
import io
import base64
from PIL import Image
import matplotlib.pyplot as plt

# Hugging Face checkpoint providing the SAM architecture config and processor.
SAM_CHECKPOINT = "facebook/sam-vit-base"
# Fine-tuned weights; resolved relative to the process's working directory.
MODEL_WEIGHTS_PATH = "../modelv2.pth"
# Prompt points per axis; the model receives GRID_SIZE * GRID_SIZE points.
GRID_SIZE = 10
# Probability cut-off used to turn the soft mask into a 0/1 mask.
MASK_THRESHOLD = 0.5

ui.tags.style(
    "#file1_progress { height: 100%; }",
    ".bslib-sidebar-layout {--_sidebar-width: 360px !important; }",
    " img { object-fit: contain; }",
)

ui.page_opts(title="Segment Anything Model: Sidewalk Masking", fillable=True)

# In Shiny Express a bare dict is applied as HTML attributes of the enclosing
# container (the page here), opting into bslib's dashboard styling.
{"class": "bslib-page-dashboard"}

with ui.sidebar():
    ui.input_file(
        "file1", "Upload Image", accept=[".jpg", ".png", ".jpeg"], multiple=False
    )
    ui.input_dark_mode(mode="dark")


@lru_cache(maxsize=1)
def _load_model():
    """Build the fine-tuned SAM model and its processor once per process.

    Returns:
        ``(model, processor, device)``: the model already moved to ``device``
        and switched to eval mode, the matching ``SamProcessor``, and the
        device string (``"cuda"`` when available, else ``"cpu"``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
    model = SamModel(config=SamConfig.from_pretrained(SAM_CHECKPOINT))
    # map_location lets weights saved from a GPU load on CPU-only hosts.
    # NOTE: torch.load unpickles the file -- only point this at trusted weights.
    model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
    model.to(device)
    model.eval()
    return model, processor, device


def _grid_points(width, height, grid_size=GRID_SIZE):
    """Return evenly spaced ``(x, y)`` prompt points covering the image.

    Points span ``[0, width - 1] x [0, height - 1]`` in pixel coordinates of
    the original image, ordered row by row (y outer, x inner), packed into a
    tensor of shape ``(1, 1, grid_size * grid_size, 2)``.
    """
    xs = np.linspace(0, width - 1, grid_size)
    ys = np.linspace(0, height - 1, grid_size)
    points = [[int(x), int(y)] for y in ys for x in xs]
    return torch.tensor(points).view(1, 1, grid_size * grid_size, 2)


def getSegments(image_path=None):
    """Predict a binary segmentation mask for an image.

    Args:
        image_path: Path of the image to segment. Defaults to the datapath of
            the file currently uploaded through the ``file1`` input (which
            makes the caller reactive on that input).

    Returns:
        A 2-D ``np.uint8`` array of 0/1 values. NOTE(review): this is the
        model's ``pred_masks`` output squeezed to 2-D; it is not resized back
        to the uploaded image's dimensions.
    """
    if image_path is None:
        image_path = input.file1()[0]["datapath"]
    model, processor, device = _load_model()

    # The context manager closes the file handle; convert("RGB") normalizes
    # grayscale / RGBA / palette uploads to 3-channel RGB.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Prompt grid spans the actual image so larger uploads are fully covered.
    input_points = _grid_points(*image.size)
    inputs = processor(image, input_points=input_points, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Sigmoid turns mask logits into probabilities, then threshold to 0/1.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
    return (prob > MASK_THRESHOLD).astype(np.uint8)


with ui.card():
    ui.card_header("Finalized Segment")

    @render.text
    def slider_val():
        """Show a caption above the mask once an image has been uploaded."""
        if input.file1() is None:
            return None
        return "Here is the prediction mask:"

    @render.image(delete_file=True)
    def render_image():
        """Segment the uploaded image and display the mask as grayscale."""
        if input.file1() is None:
            return None
        # Scale 0/1 to 0/255 so the mask is visible as a grayscale image.
        mask = (getSegments() * 255).astype(np.uint8)
        # A per-render temp file keeps concurrent sessions from clobbering one
        # shared output file; PNG is lossless, so the binary mask keeps crisp
        # edges. delete_file=True removes the file after it is sent.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            Image.fromarray(mask).save(tmp, format="PNG")
        return {"src": tmp.name, "height": "100%", "class": "contain"}