Trang Dang
adjust size
29f785c
raw history blame
No virus
1.76 kB
from pathlib import Path
from typing import List, Dict, Tuple
import pandas as pd
import seaborn as sns
import shinyswatch
import run
import os
os.environ['MPLCONFIGDIR'] = "/code/configs"
import matplotlib.pyplot as plt
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch
# Apply seaborn's default theme (it adjusts global matplotlib settings).
sns.set_theme()
# Absolute path of the "www" directory that sits next to this file; it is
# passed to App(...) as the static assets directory further down.
www_dir = Path(__file__).parent.resolve() / "www"
# Page layout: a sidebar holding the file-upload control, with the uploaded
# image preview and the prediction plot in the main area.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # Several files may be selected, but the server only reads the
            # first one.
            ui.input_file(
                id="image_input",
                label="Upload image: ",
                multiple=True,
            ),
        ),
        ui.output_image(id="image"),
        ui.output_plot(id="plot_output"),
    ),
)
def server(input: Inputs, output: Outputs, session: Session):
    """Register the outputs that react to the uploaded image.

    Two outputs are defined:

    * ``image`` shows the first uploaded file at a width of 500px.
    * ``plot_output`` passes the first uploaded file's path to ``run.pred``
      and draws the two returned arrays side by side.

    Both outputs return ``None`` (nothing rendered) until a file is uploaded.
    """

    @output
    @render.image
    def image():
        """Return the image descriptor for the first upload, or ``None``."""
        uploads = input.image_input()
        if not uploads:
            return None
        # Only the first file is displayed, even if several were uploaded.
        return {"src": uploads[0]["datapath"], "width": "500px"}

    @output
    @render.plot
    def plot_output():
        """Return a two-panel figure for the first upload, or ``None``."""
        uploads = input.image_input()
        if not uploads:
            return None

        # run.pred is project code not visible here; presumably it returns a
        # probability map and a predicted mask for the image path -- verify.
        prob_map, predicted = run.pred(uploads[0]["datapath"])

        fig, (left, right) = plt.subplots(1, 2, figsize=(15, 5))

        left.imshow(prob_map, cmap="gray")
        left.set_title("Probability Map")

        predicted_img = right.imshow(predicted)
        right.set_title("Prediction")
        # The colorbar is attached to the prediction panel only.
        fig.colorbar(predicted_img, ax=right)

        # Hide ticks and tick labels on both panels.
        for panel in (left, right):
            panel.set_xticks([])
            panel.set_yticks([])
            panel.set_xticklabels([])
            panel.set_yticklabels([])

        return fig
# Application object: combines the UI tree and the server function, and
# serves files from the "www" directory as static assets.
app = App(ui=app_ui, server=server, static_assets=str(www_dir))