# xai-ViT/app.py
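"""Streamlit app for breast-cancer diagnosis on BreakHis histopathology images.

An LRP-compatible ViT-B/16 (layers from the Transformer-Explainability codebase)
is fine-tuned per magnification (40X/100X/200X/400X) for 5-class classification;
each prediction is shown alongside an LRP explainability heatmap.
"""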
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
from Transformer_Explainability.modules.layers_ours import *
from Transformer_Explainability.baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from Transformer_Explainability.baselines.ViT.ViT_explanation_generator import LRP
from Transformer_Explainability.XAI import *
from io import BytesIO
import streamlit as st
import pandas as pd
# Preprocessing: ImageNet mean/std, since the ViT backbone is ImageNet-pretrained.
mean = [0.485, 0.456, 0.406]  # our images have three channels
std = [0.229, 0.224, 0.225]
tf = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT-B/16 expects 224x224 inputs
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Maps class index -> label (benign vs. four carcinoma subtypes).
CLS2IDX = {
    0: "Benign",
    1: "Ductal carcinoma",
    2: "Lobular carcinoma",
    3: "Mucinous carcinoma",
    4: "Papillary carcinoma",
}
st.set_page_config(
    page_title="Breast Cancer Diagnosis",
    # page_icon="👨‍⚕️",
    layout="wide",
    initial_sidebar_state="expanded",  # or "collapsed"
)
# "st.session_state :", st.session_state
if "loop" not in st.session_state:
st.session_state["loop"] = 1
def use_model(upl, mag):
    """Run the magnification-specific ViT on each uploaded image and render results."""
    num_classes = 5
    # LRP-compatible ViT-B/16 with the classifier head swapped for our 5 classes.
    model = vit_LRP(pretrained=True)
    num_ftrs = model.head.in_features
    model.head = Linear(num_ftrs, num_classes)  # Linear from layers_ours, so LRP can propagate through it
    # Load the fine-tuned checkpoint for the selected magnification (e.g. ./Model/V40X.pth).
    state_dict = torch.load("./Model/V" + mag + ".pth", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    ag = LRP(model)  # attribution generator for the explainability heatmaps

    model_pred = []
    with st.container():
        st.subheader("Result")
        for file in upl:
            bytes_data = file.read()
            # Force RGB so PNGs with an alpha channel still match the 3-channel normalization.
            img = Image.open(BytesIO(bytes_data)).convert("RGB")
            inputs = tf(img)
            outputs = model(inputs.unsqueeze(0))
            prob = torch.softmax(outputs, dim=1)
            # All five classes, ranked by predicted probability.
            class_indices = outputs.data.topk(num_classes, dim=1)[1][0].tolist()
            model_prop = []
            for cls_idx in class_indices:
                model_prop.append([CLS2IDX[cls_idx], "{:.2f}".format(prob[0, cls_idx].item() * 100)])

            img_col1, img_col2, txt_col = st.columns((1, 1, 2))
            with img_col1:
                st.image(
                    transforms.Resize((224, 224))(img),
                    output_format="auto",
                    caption="original image",
                )
            with img_col2:
                st.image(
                    generate_visualization(inputs, ag),
                    caption="explainability",
                )
            with txt_col:
                df = pd.DataFrame(
                    model_prop,
                    columns=["Class", "Probability (%)"],
                )
                st.dataframe(df, width=600)

            # Collapse the top-1 class into a benign/malignant verdict plus subtype.
            rslt = CLS2IDX[class_indices[0]]
            if rslt == "Benign":
                tp = "-"
            else:
                tp = rslt
                rslt = "Malignant"
            model_pred.append([file.name, rslt, tp])

    with st.container():
        st.subheader("Conclusion")
        df2 = pd.DataFrame(model_pred, columns=["Name", "Result", "Type"])
        st.dataframe(df2, width=800)
# --Header--
with st.container():
    st.title("Breast cancer diagnosis by transformer model")
    st.subheader("Hi :wave:")
    st.write("This website helps you diagnose breast cancer images collected by the surgical open biopsy (SOB) method.")
    st.write("The images you upload must match the BreakHis dataset. [Learn more >](https://docs.google.com/document/d/12NYTJkh2yKdR75XQhD3OwQoYFFwqIK7mkSVbd0-VWok/edit?usp=drivesdk)")
# --Sidebar--
with st.sidebar:
    # --Inputs--
    with st.container():
        st.subheader("Inputs")
        maglist = ["--Select--", "40X", "100X", "200X", "400X"]
        # Plain st.selectbox: we are already inside the sidebar context.
        magnification = st.selectbox("Select the Magnification", maglist)
    with st.container():
        uploaded = st.file_uploader(
            "Choose images to diagnose", type=["jpg", "jpeg", "png"], accept_multiple_files=True
        )

# Run the model once a magnification is selected and at least one image is uploaded.
with st.container():
    if uploaded != [] and magnification != "--Select--":
        use_model(uploaded, magnification)
        st.session_state["loop"] += 1
# "st.session_state obj :", st.session_state
# pass