"""Shiny app: upload an image and display a model's probability map and prediction.

Inference is delegated to the project-local ``run.pred``; this module only builds
the UI and renders the uploaded image plus the two maps returned by ``run.pred``.
"""

import os

# Matplotlib resolves (and caches) its config directory when it is first
# imported. seaborn imports matplotlib.pyplot, and ``run`` may as well, so
# MPLCONFIGDIR must be set before those imports or it silently has no effect.
os.environ['MPLCONFIGDIR'] = "/code/configs"

from pathlib import Path  # noqa: E402
from typing import Dict, List, Tuple  # noqa: E402,F401

import matplotlib.pyplot as plt  # noqa: E402
import pandas as pd  # noqa: E402,F401
import seaborn as sns  # noqa: E402
import shinyswatch  # noqa: E402
import torch  # noqa: E402,F401
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui  # noqa: E402,F401
from transformers import SamConfig, SamModel, SamProcessor  # noqa: E402,F401

import run  # noqa: E402

sns.set_theme()

# Directory served as static assets by the App below.
www_dir = Path(__file__).parent.resolve() / "www"

app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # multiple=True is accepted, but only the first file is used below.
            ui.input_file("image_input", "Upload image: ", multiple=True),
        ),
        ui.output_image("image"),
        ui.output_plot("plot_output", fill=True),
    ),
)


def server(input: Inputs, output: Outputs, session: Session):
    """Render a preview of the uploaded image and the model's output maps."""

    def _first_datapath():
        """Return the server-side temp path of the first upload, or None if none."""
        files = input.image_input()
        if not files:
            return None
        return files[0]['datapath']

    @output
    @render.image
    def image():
        """Show the first uploaded image at a fixed 500px width."""
        src = _first_datapath()
        if src is None:
            return None
        return {"src": src, "width": "500px"}

    @output
    @render.plot
    def plot_output():
        """Plot ``run.pred``'s probability map (viridis + colorbar) and prediction (gray)."""
        src = _first_datapath()
        if src is None:
            return None

        prob, prediction = run.pred(src)

        # NOTE(review): figsize=(60, 30) is very large; kept as-is to preserve output.
        fig, axes = plt.subplots(1, 2, figsize=(60, 30))

        im = axes[0].imshow(prob, cmap='viridis')
        axes[0].set_title("Probability Map")
        fig.colorbar(im, ax=axes[0])

        axes[1].imshow(prediction, cmap='gray')
        axes[1].set_title("Prediction")

        # Removing the ticks also removes their labels.
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])

        return fig


app = App(
    app_ui,
    server,
    static_assets=str(www_dir),
)