File size: 4,106 Bytes
98053af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f3af15
 
98053af
 
6f3af15
98053af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import base64
import functools
import io
import tempfile

import ipyleaflet as L
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from faicons import icon_svg
from geopy.distance import geodesic, great_circle
from PIL import Image
from shiny import reactive
from shiny.express import input, render, ui
from shinywidgets import render_widget
from transformers import SamConfig, SamModel, SamProcessor

# Page-wide CSS: stretch the upload progress bar, widen the sidebar, and
# letterbox images instead of stretching them.
ui.tags.style(
    "#file1_progress { height: 100%; }",
    ".bslib-sidebar-layout {--_sidebar-width: 360px !important; }",
    " img { object-fit: contain; }",
)

ui.page_opts(title="Segment Anything Model: Sidewalk Masking", fillable=True)
# NOTE(review): in Shiny Express a bare dict expression is (as in the official
# dashboard templates) presumably applied as attributes of the enclosing page,
# giving it the dashboard class — kept intentionally.
{"class": "bslib-page-dashboard"}

with ui.sidebar():
    # Single-image upload; the model reads the uploaded file's "datapath".
    ui.input_file("file1", "Upload Image", accept=[".jpg", ".png", ".jpeg"], multiple=False)
    ui.input_dark_mode(mode="dark")

with ui.card():
    # Card holding the caption text and the rendered prediction mask.

    ui.card_header("Finalized Segment")

    @render.text
    def slider_val():
        """Caption shown above the mask; renders nothing until an upload exists."""
        if input.file1() is not None:
            return "Here is the prediction mask:"
        return None
            # return input.file1()[0]['datapath']

    @functools.lru_cache(maxsize=1)
    def _load_sam_model():
        """Load the SAM processor and fine-tuned model once and reuse them.

        Returns:
            A ``(processor, model, device)`` tuple; the model is already on
            ``device`` and in eval mode.
        """
        # NOTE(review): Shiny Express re-runs the app file per session, so this
        # cache is per session; move it to a separate module to share the model
        # across sessions.
        model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Build the architecture from the config, then load the fine-tuned weights.
        model = SamModel(config=model_config)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # map_location="cpu" lets weights saved on a GPU load on CPU-only hosts;
        # the model is moved to `device` right after. The path is resolved
        # relative to the process working directory.
        model.load_state_dict(torch.load("../modelv2.pth", map_location="cpu"))
        model.to(device)
        model.eval()
        return processor, model, device

    def getSegments(grid_size=10, threshold=0.5):
        """Segment the uploaded image with the fine-tuned SAM model.

        A regular ``grid_size`` x ``grid_size`` grid of point prompts spanning
        the whole uploaded image is passed to the model as one prompt set.

        Args:
            grid_size: Number of prompt points along each image axis.
            threshold: Probability cut-off for the binary mask.

        Returns:
            A 2-D ``np.uint8`` array of 0/1 values. This is the model's
            low-resolution mask (not resized back to the upload's size).
        """
        processor, model, device = _load_sam_model()

        # Decode to RGB up front (PNG uploads may be RGBA/palette) and close the
        # file handle; the image size is also needed for the prompt grid.
        with Image.open(input.file1()[0]["datapath"]) as img:
            image = img.convert("RGB")
        width, height = image.size

        # Evenly spaced (x, y) prompts covering the full image, in original-image
        # pixel coordinates (SamProcessor rescales them to the model input).
        xv, yv = np.meshgrid(
            np.linspace(0, width - 1, grid_size),
            np.linspace(0, height - 1, grid_size),
        )
        points = np.stack([xv.ravel(), yv.ravel()], axis=-1).astype(np.int64)
        # Shape (image_batch=1, point_batch=1, num_points, 2).
        input_points = torch.from_numpy(points).reshape(1, 1, grid_size * grid_size, 2)

        inputs = processor(image, input_points=input_points, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)

        # Sigmoid turns logits into probabilities; thresholding gives a hard mask.
        probs = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
        return (probs > threshold).astype(np.uint8)

    # @render.image
    # def render_image():
    #     # Get the uploaded file
    #     uploaded_file = input.file1()

    #     # If there is no uploaded file, return None
    #     if uploaded_file is None:
    #         return None

    #     # Read the image file
    #     imagePath = uploaded_file[0]['datapath']

    #     # processImage()
    #     return {"src": imagePath, "width": "100%"}

    @render.image(delete_file=True)
    def render_image():
        """Render the predicted sidewalk mask for the uploaded image.

        The 0/1 mask from ``getSegments`` is scaled to 0/255 grayscale and
        written to a per-call temporary PNG. Shiny deletes the file after
        serving it.

        Returns:
            ``None`` when nothing is uploaded, otherwise an image dict for
            ``render.image``.
        """
        if input.file1() is None:
            return None

        # 0/1 -> 0/255 so the mask is visible as a black/white image.
        mask = getSegments().astype(np.uint8) * 255
        image = Image.fromarray(mask)

        # Use a unique temp file per call instead of a shared "test.jpg", so
        # concurrent sessions cannot overwrite each other. PNG is lossless, so
        # mask edges avoid JPEG artifacts.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
            image_path = tmp.name

        return {"src": image_path, "height": "100%", "class": "contain"}