Spaces:
Sleeping
Sleeping
import ipyleaflet as L | |
from transformers import SamModel, SamConfig, SamProcessor | |
import torch | |
from faicons import icon_svg | |
from geopy.distance import geodesic, great_circle | |
from shiny import reactive | |
from shiny.express import input, render, ui | |
from shinywidgets import render_widget | |
import numpy as np | |
import ipywidgets as widgets | |
import io | |
import base64 | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
# Page-level CSS: make the `file1` upload progress element full height, widen the
# sidebar to 360px, and keep every <img> letterboxed via object-fit: contain.
ui.tags.style(
    "#file1_progress { height: 100%; }",
    ".bslib-sidebar-layout {--_sidebar-width: 360px !important; }",
    " img { object-fit: contain; }",
)

ui.page_opts(title="Segment Anything Model: Sidewalk Masking", fillable=True)
# NOTE(review): bare dict expression, kept on purpose — presumably Shiny Express
# picks it up as attributes for the page (as in Shiny's dashboard templates);
# confirm before removing it as a "useless expression".
{"class": "bslib-page-dashboard"}

# Sidebar: single-image upload (JPEG/PNG) plus a dark-mode toggle that starts dark.
with ui.sidebar():
    # No trailing comma: the original one wrapped this call in an accidental tuple.
    ui.input_file("file1", "Upload Image", accept=[".jpg", ".png", ".jpeg"], multiple=False)
    ui.input_dark_mode(mode="dark")

# Card for the segmentation result.
# NOTE(review): the scraped source lost indentation; the outputs defined below were
# presumably meant to sit inside this card — confirm against the deployed app.
with ui.card():
    ui.card_header("Finalized Segment")
def slider_val():
    """Caption shown above the predicted mask.

    Returns:
        The caption string once a file has been uploaded through ``file1``,
        otherwise ``None`` (nothing to caption yet).
    """
    if input.file1() is not None:
        return "Here is the prediction mask:"
    return None
# Hugging Face Hub id of the base SAM architecture the fine-tuned weights belong to.
_SAM_BASE_ID = "facebook/sam-vit-base"
# Fine-tuned weights (a state_dict), resolved relative to the process working directory.
_SAM_WEIGHTS_PATH = "../modelv2.pth"
# Lazily filled with a single (model, processor, device) entry so the expensive
# config/processor/weight loading happens once per process instead of per upload.
_SAM_CACHE = {}


def _load_sam():
    """Return the cached ``(model, processor, device)`` triple, building it on first use."""
    if "bundle" not in _SAM_CACHE:
        model_config = SamConfig.from_pretrained(_SAM_BASE_ID)
        processor = SamProcessor.from_pretrained(_SAM_BASE_ID)
        model = SamModel(config=model_config)
        # map_location="cpu" lets weights that were saved from a GPU load on a
        # CPU-only host; the model is moved to the chosen device afterwards.
        model.load_state_dict(torch.load(_SAM_WEIGHTS_PATH, map_location="cpu"))
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        _SAM_CACHE["bundle"] = (model, processor, device)
    return _SAM_CACHE["bundle"]


def _grid_points(array_size=256, grid_size=10):
    """Return a ``grid_size`` x ``grid_size`` lattice of integer (x, y) point prompts.

    Coordinates span ``0 .. array_size - 1`` on both axes (truncated to int) and
    are flattened row by row into a tensor of shape ``(1, 1, grid_size**2, 2)``.
    """
    ticks = np.linspace(0, array_size - 1, grid_size)
    xv, yv = np.meshgrid(ticks, ticks)
    grid = np.stack((xv, yv), axis=-1).astype(np.int64)
    return torch.tensor(grid).view(1, 1, grid_size * grid_size, 2)


def getSegments():
    """Segment the uploaded image with the fine-tuned SAM model.

    Prompts the model with a fixed 10x10 grid of points, thresholds the sigmoid
    of the predicted mask at 0.5 and returns it as a ``uint8`` array of 0/1.

    NOTE(review): the grid covers pixel coordinates 0-255 only, so for uploads
    larger than 256x256 the prompts presumably land in the top-left region only;
    likewise ``pred_masks`` is presumably SAM's low-resolution mask and is not
    resized back to the upload's size — confirm inputs are expected to be 256x256.
    """
    model, processor, device = _load_sam()
    input_points = _grid_points()

    # The context manager closes the upload's file handle once the processor
    # has read the pixels.
    with Image.open(input.file1()[0]["datapath"]) as image:
        inputs = processor(image, input_points=input_points, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Soft mask (probabilities) -> hard 0/1 mask.
    probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1))
    probabilities = probabilities.cpu().numpy().squeeze()
    return (probabilities > 0.5).astype(np.uint8)
# @render.image | |
# def render_image(): | |
# # Get the uploaded file | |
# uploaded_file = input.file1() | |
# # If there is no uploaded file, return None | |
# if uploaded_file is None: | |
# return None | |
# # Read the image file | |
# imagePath = uploaded_file[0]['datapath'] | |
# # processImage() | |
# return {"src": imagePath, "width": "100%"} | |
def render_image():
    """Segment the current upload and return image data describing the mask.

    Returns:
        ``None`` when nothing has been uploaded; otherwise a dict with ``"src"``
        (path to a PNG of the mask, 255 = predicted foreground, 0 = background),
        plus ``"height"`` and ``"class"`` for the rendered ``<img>`` tag.
    """
    import tempfile

    if input.file1() is None:
        return None

    # getSegments() yields a 0/1 uint8 mask; scale it to 0/255 for an 8-bit grayscale image.
    mask = np.asarray(getSegments(), dtype=np.uint8) * 255

    # A unique file per call keeps concurrent sessions from overwriting each
    # other's result (the old fixed "test.jpg" was shared). PNG is lossless, so
    # the mask stays strictly 0/255, whereas JPEG blurs the edges.
    # NOTE(review): these files are not removed here; if this function is
    # registered with Shiny's @render.image, passing delete_file=True would let
    # Shiny delete each file after sending it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.fromarray(mask).save(tmp, format="PNG")

    return {"src": tmp.name, "height": "100%", "class": "contain"}