"""Streamlit front-end for the LastMinute Medical WSI diagnosis tool.

The sidebar selects one of four pages: "About" (renders the project README),
"Visualize WSI slide" (shows a downscaled thumbnail of an uploaded slide),
"Cancer Detection" (runs the predictor and shows its heatmap) and "Contact".
"""
import os
import urllib.request

import streamlit as st
import openslide
from streamlit_option_menu import option_menu
import torch  # noqa: F401  (kept from the original file; not referenced here)

# Graph dependencies were historically installed with:
#   pip install torch-scatter   -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html
#   pip install torch-sparse    -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html
#   pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html

README_URL = ("https://raw.githubusercontent.com/Chris1nexus/"
              "inference-graph-transformer/master/README.md")
# Each thumbnail side is the slide side divided by this factor.
THUMBNAIL_REDUCTION_FACTOR = 20
UPLOAD_LABEL = "Choose a WSI slide file to diagnose (.svs)"

CONTACT_TEXT = """
_Built by Christian Cancedda and LabLab lads with love_ ❤️
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)

Star project repository:
[![GitHub stars](https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
"""
VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
GDC_EXAMPLES_TEXT = "Examples can be chosen at the [GDC Data repository](https://portal.gdc.cancer.gov/repository?facetTab=cases&filters=%7B%22op%22%3A%22and%22%2C%22content%22%3A%5B%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.primary_site%22%2C%22value%22%3A%5B%22bronchus%20and%20lung%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.program.name%22%2C%22value%22%3A%5B%22TCGA%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.project_id%22%2C%22value%22%3A%5B%22TCGA-LUAD%22%2C%22TCGA-LUSC%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.experimental_strategy%22%2C%22value%22%3A%5B%22Tissue%20Slide%22%5D%7D%7D%5D%7D)"
DRIVE_EXAMPLES_TEXT = "Alternatively, for simplicity few test cases are provided at the [drive link](https://drive.google.com/drive/folders/1u3SQa2dytZBHHh6eXTlMKY-pZGZ-pwkk?usp=share_link)"
HEATMAP_TEXT = 'Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected'


def _cache_model_loader(func):
    """Cache *func*'s return value across Streamlit reruns.

    Prefers ``st.cache_resource`` (Streamlit's API for caching models and
    other unhashable resources). On older Streamlit releases that lack it,
    this falls back to the deprecated ``st.cache`` call the app originally used.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(suppress_st_warning=True)(func)


@_cache_model_loader
def load_model():
    """Construct and return the ``Predictor`` from the local ``predict`` module.

    The import is deferred to call time. ``main`` exports the weight and
    metadata paths into ``os.environ`` before calling this function.
    Presumably ``predict`` reads them on import or construction — TODO confirm.
    """
    from predict import Predictor
    return Predictor()


def _configure_environment():
    """Export the directory and weight paths read by the inference API."""
    data_dir = 'queries'
    os.environ['DATA_DIR'] = data_dir
    os.environ['PATCHES_DIR'] = os.path.join(data_dir, 'patches')
    os.environ['SLIDES_DIR'] = os.path.join(data_dir, 'slides')
    os.environ['GRAPHCAM_DIR'] = os.path.join(data_dir, 'graphcam_plots')
    os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
    # The label map must be placed in the metadata folder manually.
    os.environ['CLASS_METADATA'] = 'metadata/label_map.pkl'
    # The desired weights must be placed in the weights folder manually.
    weights_dir = 'weights'
    os.environ['WEIGHTS_PATH'] = weights_dir
    os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(
        weights_dir, 'feature_extractor', 'model.pth')
    os.environ['GT_WEIGHT_PATH'] = os.path.join(
        weights_dir, 'graph_transformer', 'GraphCAM.pth')


def _save_upload(uploaded_file):
    """Write a Streamlit upload into the working directory; return its path.

    Only the base name of the uploaded file is kept. This stops a crafted
    name (e.g. containing ``../``) from writing outside the working directory.
    """
    path = os.path.basename(uploaded_file.name)
    with open(path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return path


def _render_about():
    """Render the project README fetched from GitHub."""
    try:
        with urllib.request.urlopen(README_URL, timeout=10) as response:
            readme = response.read().decode('utf-8', errors='replace')
    except OSError as err:  # URLError, HTTPError and timeouts are OSErrors
        st.error(f'Could not load the project README: {err}')
        return
    # Shrink the README's wide images to fit the page.
    readme = readme.replace('width="1200"', 'width="700"')
    st.markdown(readme, unsafe_allow_html=True)


def _render_visualize():
    """Show a downscaled thumbnail of an uploaded whole-slide image."""
    st.markdown(VISUALIZE_TEXT)
    uploaded_file = st.file_uploader(UPLOAD_LABEL)
    if uploaded_file is None:
        return
    # OpenSlide reads from a filesystem path, so the upload is persisted first.
    slide = openslide.OpenSlide(_save_upload(uploaded_file))
    try:
        width, height = slide.dimensions
        size = (max(1, int(width / THUMBNAIL_REDUCTION_FACTOR)),
                max(1, int(height / THUMBNAIL_REDUCTION_FACTOR)))
        # get_thumbnail keeps the aspect ratio; resize forces the exact size.
        thumbnail = slide.get_thumbnail(size).resize(size)
    finally:
        slide.close()
    st.image(thumbnail, use_column_width='never')


def _render_detection(predictor):
    """Run *predictor* on an uploaded slide and show the original and heatmap."""
    st.markdown(DETECT_TEXT)
    uploaded_file = st.file_uploader(UPLOAD_LABEL)
    st.markdown(GDC_EXAMPLES_TEXT)
    st.markdown(DRIVE_EXAMPLES_TEXT)
    if uploaded_file is None:
        return
    slide_path = _save_upload(uploaded_file)
    with st.spinner(text="Computation is running"):
        predicted_class, viz_dict = predictor.predict(slide_path)
    st.info('Computation completed.')
    st.header(f'Predicted to be: {predicted_class}')
    st.text(HEATMAP_TEXT)
    # viz_dict['ORI'] is the original slide image; viz_dict[predicted_class]
    # is the heatmap. The image order must match the caption order.
    st.image(
        [viz_dict['ORI'], viz_dict[predicted_class]],
        caption=['Original', f'{predicted_class} heatmap'],
        channels='BGR',
    )


def _render_contact():
    """Render the authors' contact links."""
    st.markdown(CONTACT_TEXT)


def main():
    """Configure the inference environment, load the model and render the UI."""
    _configure_environment()
    predictor = load_model()

    with st.sidebar:
        choice = option_menu(
            "LastMinute - Diagnosis",
            ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
            icons=['house', 'upload', 'activity', 'person-lines-fill'],
            menu_icon="app-indicator",
            default_index=0,
            styles={"container": {"border-radius": ".0rem"}},
        )
    # Whitespace-only markdown kept from the original; renders nothing visible.
    st.sidebar.markdown("\n", unsafe_allow_html=True)

    pages = {
        "About": _render_about,
        "Visualize WSI slide": _render_visualize,
        "Cancer Detection": lambda: _render_detection(predictor),
        "Contact": _render_contact,
    }
    render = pages.get(choice)
    if render is not None:
        st.title(choice)
        render()


if __name__ == '__main__':
    main()