File size: 1,803 Bytes
52dce0a
 
 
 
 
915d664
92badd3
 
7008c4a
52dce0a
 
83fe5b0
 
52dce0a
 
 
 
 
 
 
 
 
55a5bfa
52dce0a
55a5bfa
3b197de
52dce0a
 
 
 
 
55a5bfa
 
 
9d3accc
 
e01d5f6
9d3accc
 
f03708e
 
d152f7f
 
1d5816e
 
86680ae
0a1d185
d152f7f
99ab932
d152f7f
e01d5f6
 
d152f7f
7c72540
d152f7f
 
 
 
 
 
 
 
1d5816e
52dce0a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pathlib import Path
from typing import List, Dict, Tuple
import pandas as pd
import seaborn as sns
import shinyswatch
import run
import os
os.environ['MPLCONFIGDIR'] = "/code/configs"
import matplotlib.pyplot as plt

from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Apply seaborn's default styling globally to every matplotlib figure drawn below.
sns.set_theme()

# Directory of static assets served by the app: the "www" folder next to this file.
www_dir = Path(__file__).parent.resolve().joinpath("www")

# Page layout: a minty-themed fillable page with a sidebar holding the image
# upload, and a main panel with the uploaded image and the prediction plot.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # The server only ever processes the first uploaded file
            # (input.image_input()[0]), so accept exactly one file rather than
            # silently ignoring extras; the accept hint limits the browser's
            # file picker to images (advisory only, not enforced server-side).
            ui.input_file(
                "image_input",
                "Upload image: ",
                multiple=False,
                accept=["image/*"],
            ),
        ),
        ui.output_image("image"),
        ui.output_plot("plot_output", fill=True),
    ),
)


def server(input: Inputs, output: Outputs, session: Session):
    """Register the app's reactive outputs.

    Both outputs depend on the ``image_input`` upload; only the first
    uploaded file is used. When nothing has been uploaded, each output
    returns ``None`` so nothing is rendered.
    """

    def _first_upload_path():
        """Return the server-side temp path of the first upload, or None."""
        uploads = input.image_input()
        if not uploads:
            return None
        return uploads[0]["datapath"]

    @output
    @render.image
    def image():
        """Display the uploaded image at a fixed 500px width."""
        src = _first_upload_path()
        if src is None:
            return None
        return {"src": src, "width": "500px"}

    @output
    @render.plot
    def plot_output():
        """Plot the model's probability map and prediction side by side.

        ``run.pred`` is expected to return a ``(prob, prediction)`` pair of
        images that ``imshow`` can display (presumably 2-D arrays — confirm
        against ``run.pred``).
        """
        src = _first_upload_path()
        if src is None:
            return None
        prob, prediction = run.pred(src)
        # NOTE(review): 60x30 inches is very large; Shiny's render.plot may
        # resize the figure to the output container — confirm before tuning.
        fig, (ax_prob, ax_pred) = plt.subplots(1, 2, figsize=(60, 30))

        im = ax_prob.imshow(prob, cmap="viridis")
        ax_prob.set_title("Probability Map")
        fig.colorbar(im, ax=ax_prob)

        ax_pred.imshow(prediction, cmap="gray")
        ax_pred.set_title("Prediction")

        for ax in (ax_prob, ax_pred):
            # Removing the ticks also removes their labels, so no separate
            # set_*ticklabels([]) call is needed.
            ax.set_xticks([])
            ax.set_yticks([])
        return fig


# The Shiny application object: UI definition, server logic, and the static
# "www" directory (passed as a string path) for serving assets.
app = App(
    ui=app_ui,
    server=server,
    static_assets=str(www_dir),
)