# xai-ViT / app.py
# Last commit 64e6f71 "Update app.py" by Duckin (7.24 kB).
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
from collections import OrderedDict
from Transformer_Explainability.modules.layers_ours import *
from Transformer_Explainability.baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from Transformer_Explainability.baselines.ViT.ViT_explanation_generator import LRP
from Transformer_Explainability.XAI import *
from io import BytesIO
from PIL import Image
import streamlit as st
import pandas as pd
from streamlit_option_menu import option_menu
# import Pages as pg
# mean = [0.485, 0.456, 0.406] # our images have three channels
# std = [0.229, 0.224, 0.225]
# normalize = transforms.Normalize(mean=mean, std=std)
# transform = transforms.Compose([
# transforms.Resize((224,224)),
# transforms.ToTensor(),
# normalize,
# ])
# resizing = transforms.Resize((224,224))
# Maps the classifier's output index to a human-readable diagnosis label.
# Index 0 is the only benign class; indices 1-4 are carcinoma sub-types.
CLS2IDX = dict(enumerate([
    "Benign",
    "Ductal carcinoma",
    "Lobular carcinoma",
    "Mucinous carcinoma",
    "Papillary carcinoma",
]))
# Page-level settings for the Streamlit app.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",  # alternative: "collapsed"
    page_title="Breast Cancer Diagnosis",
    # page_icon="👨‍⚕️",
)

# Per-session counter: created with the value 1 only when it is not already
# stored in st.session_state.
if "loop" not in st.session_state:
    st.session_state.loop = 1
def use_model(upl, mag):
    """Run the ViT classifier for one magnification and render the results.

    Builds a ViT-LRP model whose head has ``len(CLS2IDX)`` outputs and loads the
    weights from ``./Model/V<mag>.pth`` onto the CPU. For every uploaded file,
    it shows three things:

    - the original image, resized;
    - the image returned by ``generate_visualization``;
    - a table of all class probabilities, highest first.

    It then renders a "Conclusion" table with one row per file: name,
    Benign/Malignant, and the carcinoma type (or "-").

    Args:
        upl: Uploaded files, as returned by ``st.file_uploader`` with
            ``accept_multiple_files=True``. Each must provide ``getvalue()``
            and ``name``.
        mag: Magnification label such as ``"40X"``; selects the weights file.
    """
    # The head size must match the label table used to name the outputs.
    num_classes = len(CLS2IDX)

    # NOTE(review): the strict load_state_dict below overwrites every weight,
    # so pretrained=True presumably only costs an extra download. Confirm
    # before switching it off.
    model = vit_LRP(pretrained=True)
    model.head = Linear(model.head.in_features, num_classes)
    state_dict = torch.load("./Model/V" + mag + ".pth", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    ag = LRP(model)

    model_pred = []
    with st.container():
        st.subheader("Result")
        for file in upl:
            # getvalue() returns the whole upload regardless of the current read
            # position. convert("RGB") turns RGBA or grayscale PNGs into the
            # 3-channel image the model is built for; without it they produce a
            # tensor with the wrong channel count.
            img = Image.open(BytesIO(file.getvalue())).convert("RGB")
            # NOTE(review): `tf` is not defined in this file; presumably it comes
            # from the Transformer_Explainability.XAI star import. Confirm.
            inputs = tf(img)
            outputs = model(inputs.unsqueeze(0))
            prob = torch.softmax(outputs, dim=1)
            # Indices of all classes, ordered by descending logit.
            class_indices = outputs.data.topk(num_classes, dim=1)[1][0].tolist()
            model_prop = [
                [CLS2IDX[cls_idx], "{:.2f}".format(prob[0, cls_idx] * 100)]
                for cls_idx in class_indices
            ]

            img_col1, img_col2, txt_col = st.columns((1, 1, 2))
            with img_col1:
                st.image(
                    transforms.Resize((224, 224))(img),
                    output_format="auto",
                    caption="original image",
                )
            with img_col2:
                st.image(
                    generate_visualization(inputs, ag),
                    caption="explainability",
                )
            with txt_col:
                df = pd.DataFrame(model_prop, columns=["Class", "Probability"])
                st.dataframe(df, width=600)

            # Collapse the top class into the Benign/Malignant verdict;
            # malignant rows keep the carcinoma sub-type as "Type".
            top_label = CLS2IDX[class_indices[0]]
            if top_label == "Benign":
                rslt, tp = top_label, "-"
            else:
                rslt, tp = "Malignant", top_label
            model_pred.append([file.name, rslt, tp])

    with st.container():
        st.subheader("Conclusion")
        df2 = pd.DataFrame(model_pred, columns=["Name", "Result", "Type"])
        with st.container():
            st.dataframe(df2, width=800)
# --Header--
with st.container():
    st.title("Breast cancer diagnosis by transformer model")
    st.subheader("Hi :wave:")
    st.write("This website helps you diagnose breast cancer from images collected by the surgical open biopsy (SOB) method.")
    st.write("The images you upload must be similar to those in the BREAKHIS dataset [Learn more >](https://docs.google.com/document/d/12NYTJkh2yKdR75XQhD3OwQoYFFwqIK7mkSVbd0-VWok/edit?usp=drivesdk)")

# --Sidebar--
with st.sidebar:
    # --Inputs--
    with st.container():
        st.subheader("Inputs")
        maglist = ["--Select--", "40X", "100X", "200X", "400X"]
        magnification = st.sidebar.selectbox("Select the Magnification", maglist)
    with st.container():
        uploaded = st.file_uploader(
            "Choose images to diagnose", type=["jpg", "jpeg", "png"], accept_multiple_files=True
        )
    # with st.container():
    #     diagnose = st.button("Diagnose")

# --Results--
# Run the model only once at least one image is uploaded and a real
# magnification (not the placeholder) is chosen. The truthiness check
# covers both an empty list and None.
with st.container():
    if uploaded and magnification != "--Select--":
        use_model(uploaded, magnification)
        st.session_state["loop"] += 1
# "st.session_state obj :", st.session_state
# pass
# selected = option_menu(
# menu_title=None, # required
# options=["Home", "How to use", "Contact"]
# )
# if selected == "Home":
# inputs()
# elif selected == "How to use":
# pg.howtouse()
# elif selected == "Contact":
# pg.contact()
# st.set_page_config(
# page_title="Breast Cancer Diagnosis",
# page_icon="πŸ‘¨β€βš•οΈ",
# layout="wide",
# initial_sidebar_state="expanded", #collapsed
# )
# st.title("Breast Cancer Diagnosis BY Transformer Model")
# st.sidebar.subheader("Input")
# "st.session_state obj :", st.session_state
# if 'boolean' not in st.session_state:
# st.session_state['boolean'] = False
# "st.session_state obj :", st.session_state
# models_list = ["--Select--", "40X", "100X", "200X", "400X"]
# magnitude = st.sidebar.selectbox("Select the Magnification", models_list)
# uploaded_file = st.sidebar.file_uploader(
# "Choose images to diagnose", type=["jpg", "jpeg", "png"], accept_multiple_files = True
# )
# diagnosis = st.sidebar.button("Diagnose")
# if diagnosis:
# st.session_state['boolean'] = True
# num_classes = 5
# model = vit_LRP(pretrained=True)
# num_ftrs = model.head.in_features
# model.head = Linear(num_ftrs, num_classes)
# if uploaded_file != [] and st.session_state['boolean'] and magnitude != "--Select--":
# state_dict = torch.load("./Model/V" + magnitude + ".pth", map_location=torch.device('cpu'))
# model.load_state_dict(state_dict)
# model.eval()
# ag = LRP(model)
# predictions = []
# xai = []
# for each_file in uploaded_file:
# bytes_data = each_file.read()
# img = Image.open(BytesIO(bytes_data))
# inputs = transform(img)
# outputs = model(inputs.unsqueeze(0))
# model_predicted = CLS2IDX[print_top_classes(outputs)[0]]
# if model_predicted == "benign":
# type = '-'
# else:
# type = model_predicted
# model_predicted = "malignant"
# xai.append(generate_visualization(inputs, ag))
# predictions.append([img, each_file.name, model_predicted, type])
# #st.write(each_file)
# #st.write(each_file.name)
# # st.image(img)
# #st.write(model_predicted)
# def showim():
# df = pd.DataFrame(
# predictions, columns=["Image", "Name", "Result", "Type"]
# )
# st.dataframe(
# df.iloc[:,1:], width=1000
# )
# idx = st.select_slider("Select the index of picture", df.index, on_change=showim)
# st.write(idx)
# # st.image(predictions[index][0])
# st.image(xai[idx])
# showim()
# print(uploaded_file, diagnosis, magnitude)