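"""Shiny for Python app for sidewalk segmentation.

Uploads a .tif tile, runs it through a fine-tuned SAM (Segment Anything Model)
checkpoint, and plots the input image, the predicted probability map, and the
binarized mask side by side.
"""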
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from transformers import SamConfig, SamModel, SamProcessor

app_ui = ui.page_fluid(
    ui.input_file("file1", "Upload Tile image for sidewalk segmentation", accept=".tif", multiple=False),
    ui.output_plot("mask"),  # Changed from ui.output_table to ui.output_plot based on the context of output
)

def server(input: Inputs, output: Outputs, session: Session):
    @reactive.calc
    def parsed_file():
        file_info = input.file1()
        if file_info is None or len(file_info) == 0:
            return None
        return file_info[0]["datapath"]
    
    @output
    @render.plot
    async def mask():
        filepath = parsed_file()
        if filepath is None:
            return
        print(filepath)
        # Build the SAM architecture from the base config and load the fine-tuned
        # sidewalk checkpoint. (This reloads the weights on every render; hoisting it
        # to module scope would avoid repeating the load for each uploaded image.)
        model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        my_sidewalk_model = SamModel(model_config)
        my_sidewalk_model.load_state_dict(torch.load("./sidwalk_model_checkpoint.pth", map_location='cpu'))
        device = torch.device("cpu")
        my_sidewalk_model.to(device)

        # Load image
        image = Image.open(filepath)
        imarray = np.array(image)
        single_patch = Image.fromarray(imarray)

        inputs = processor(single_patch, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        my_sidewalk_model.eval()
        # Model inference
        with torch.no_grad():
            outputs = my_sidewalk_model(**inputs, multimask_output=False)
        # Convert the mask logits to per-pixel probabilities, then threshold at 0.5
        # to get a hard (binary) mask
        single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
        single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)


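        # Note: pred_masks come back at SAM's low mask resolution, so the panels below
        # show the model-resolution mask. If a full-resolution overlay were needed,
        # processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"],
        # inputs["reshaped_input_sizes"]) could be used to resize the masks back to
        # the original image size.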
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Left panel: the uploaded tile
        axes[0].imshow(np.array(single_patch), cmap='gray')
        axes[0].set_title("Image")

        # Middle panel: the sigmoid probability map
        axes[1].imshow(single_patch_prob)
        axes[1].set_title("Probability Map")

        # Right panel: the binarized prediction mask
        axes[2].imshow(single_patch_prediction, cmap='gray')
        axes[2].set_title("Prediction")

        # Hide axis ticks and labels
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        # Return the figure; Shiny renders it in the "mask" plot output
        return fig

    

app = App(app_ui, server)
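
# To run locally (assuming this file is saved as app.py and the fine-tuned checkpoint
# ./sidwalk_model_checkpoint.pth sits next to it):
#     shiny run --reload app.py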