# Milestone4 / app.py
# Author: jasoncordova
# Last commit: "Update app.py" (98053af, verified)
import base64
import functools
import io
import tempfile

import ipyleaflet as L
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from faicons import icon_svg
from geopy.distance import geodesic, great_circle
from PIL import Image
from shiny import reactive
from shiny.express import input, render, ui
from shinywidgets import render_widget
from transformers import SamModel, SamConfig, SamProcessor
# Global CSS overrides: a full-height progress bar for the upload control,
# a fixed 360px sidebar, and images scaled to fit (object-fit: contain).
_APP_CSS_RULES = (
    "#file1_progress { height: 100%; }",
    ".bslib-sidebar-layout {--_sidebar-width: 360px !important; }",
    " img { object-fit: contain; }",
)
ui.tags.style(*_APP_CSS_RULES)

ui.page_opts(fillable=True, title="Segment Anything Model: Sidewalk Masking")
# NOTE(review): this bare dict is presumably picked up by Shiny Express as HTML
# attributes for the enclosing page (dashboard styling) — it is not dead code.
{"class": "bslib-page-dashboard"}
# Sidebar: a single-image upload control (id "file1", read by the results card)
# and a dark-mode toggle initialised to dark.
with ui.sidebar():
    # No trailing comma: with one, the call became a 1-tuple expression.
    ui.input_file("file1", "Upload Image", accept=[".jpg", ".png", ".jpeg"], multiple=False)
    ui.input_dark_mode(mode="dark")
# Model loading and inference helpers. They render nothing themselves, so they
# live at module level; the card below holds the visible outputs.


@functools.lru_cache(maxsize=1)
def _load_sam_model():
    """Build the fine-tuned SAM model and its processor once.

    Cached so the config, processor and weights are read a single time per
    execution of this script instead of on every render.
    NOTE(review): Shiny Express re-executes app.py for each session, so this
    presumably caches per session; move it into a separately imported module
    to share one copy across all sessions.

    Returns:
        ``(model, processor, device)``, with ``model`` already moved to
        ``device`` and switched to eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    # Base SAM architecture; the fine-tuned weights come from the saved file.
    model = SamModel(config=model_config)
    # map_location lets weights saved on a GPU load on a CPU-only host.
    # NOTE(review): the path is resolved against the process working directory.
    model.load_state_dict(torch.load("../modelv2.pth", map_location=device))
    model.to(device)
    model.eval()
    return model, processor, device


def _grid_prompt_points(array_size=256, grid_size=10):
    """Return a regular ``grid_size`` x ``grid_size`` grid of (x, y) prompt points.

    Points are evenly spaced over ``[0, array_size - 1]`` on both axes,
    truncated to integers, and shaped ``(1, 1, grid_size**2, 2)``
    (presumably: one image, one prompt containing every grid point).
    """
    axis = np.linspace(0, array_size - 1, grid_size)
    xv, yv = np.meshgrid(axis, axis)
    # Row-major (x, y) pairs; astype truncates toward zero, like int().
    points = np.stack([xv, yv], axis=-1).astype(np.int64)
    return torch.from_numpy(points).view(1, 1, grid_size * grid_size, 2)


def getSegments(image_path=None):
    """Segment an image with the fine-tuned SAM model and return a binary mask.

    Args:
        image_path: Path of the image to segment. When omitted, the file
            currently uploaded through the ``file1`` input is used.

    Returns:
        A 2-D ``np.uint8`` array of 0/1 values, where 1 marks pixels whose
        sigmoid probability exceeds 0.5. Its size is the model's mask output
        size (presumably 256x256 low-res masks, not the upload's size —
        TODO confirm).

    Raises:
        ValueError: if ``image_path`` is omitted and no file has been uploaded.
    """
    if image_path is None:
        uploaded = input.file1()
        if uploaded is None:
            raise ValueError("No image has been uploaded.")
        image_path = uploaded[0]["datapath"]

    model, processor, device = _load_sam_model()

    # Decode fully and release the file handle. Converting to RGB means
    # alpha/palette PNG uploads reach the processor as plain 3-channel RGB.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # NOTE(review): the prompt points span pixels 0..255. If the processor treats
    # them as original-image coordinates, larger uploads are only prompted in
    # their top-left 256x256 region — confirm this matches the training setup.
    inputs = processor(image, input_points=_grid_prompt_points(), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities -> hard {0, 1} mask.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
    return (prob > 0.5).astype(np.uint8)


with ui.card():
    ui.card_header("Finalized Segment")

    @render.text
    def slider_val():
        """Caption shown above the mask once an image has been uploaded."""
        if input.file1() is None:
            return None
        return "Here is the prediction mask:"

    # delete_file=True: Shiny removes the temporary file once it has been served.
    @render.image(delete_file=True)
    def render_image():
        """Segment the uploaded image and display the mask in black and white."""
        uploaded_file = input.file1()
        if uploaded_file is None:
            return None
        mask = getSegments(uploaded_file[0]["datapath"])
        # Scale {0, 1} to {0, 255} so the mask is visible.
        image = Image.fromarray((mask * 255).astype(np.uint8))
        # A unique temp file per render, so concurrent sessions cannot overwrite
        # each other's output. PNG is lossless, so the binary mask keeps clean edges.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
        return {"src": tmp.name, "height": "100%", "class": "contain"}