"""Streamlit demo for explainable oral lesion detection with a Detectron2 Mask R-CNN."""

import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from packaging import version
from scipy import spatial

import detectron2.data.transforms as T
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import Metadata
from detectron2.engine import DefaultPredictor
from detectron2.structures import Instances
from detectron2.structures.boxes import Boxes
from detectron2.utils.visualizer import Visualizer

from plots.plot_pca_point import plot_pca_point
from plots.plot_histogram_dist import plot_histogram_dist
from plots.plot_gradcam import plot_gradcam


def extract_features(model, img, box):
    """Extract the 256-d ROI feature vector(s) for the given box(es) from an image tensor."""
    height, width = img.shape[1:3]
    inputs = [{"image": img, "height": height, "width": width}]
    with torch.no_grad():
        img = model.preprocess_image(inputs)
        features = model.backbone(img.tensor)
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features_, [box])
        # Average-pool each 7x7 ROI grid down to a single 256-d descriptor.
        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)
    return output_features


def forward_model_full(model, cfg, cv_img):
    """Run the detection pipeline step by step so the per-proposal class
    probabilities and ROI features can be attached to the predicted instances."""
    height, width = cv_img.shape[:2]
    transform_gen = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    image = transform_gen.get_transform(cv_img).apply_image(cv_img)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    inputs = [{"image": image, "height": height, "width": width}]
    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        proposals, _ = model.proposal_generator(images, features, None)
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(
            features_, [x.proposal_boxes for x in proposals]
        )
        box_head = model.roi_heads.box_head(box_features)
        predictions = model.roi_heads.box_predictor(box_head)
        # 256-d descriptor per proposal, reused below for the case-based explanations.
        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)
        probs = model.roi_heads.box_predictor.predict_probs(predictions, proposals)
        pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
        pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)
        pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)
        instances = pred_instances[0]["instances"]
        # Keep the full softmax distribution and the ROI features of the surviving proposals.
        instances.set("probs", probs[0][pred_inds])
        instances.set("features", output_features[pred_inds])
    return instances, cv_img


def load_model():
    """Build the Mask R-CNN predictor and the metadata for the three lesion classes."""
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    )
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.WEIGHTS = MODEL
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
    cfg.MODEL.DEVICE = "cpu"
    metadata = Metadata()
    metadata.set(
        evaluator_type="coco",
        thing_classes=["neoplastic", "aphthous", "traumatic"],
        thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2},
    )
    predictor = DefaultPredictor(cfg)
    model = predictor.model
    return dict(predictor=predictor, model=model, metadata=metadata, cfg=cfg)


def draw_box(img, box, pred_class, model, resize_input=False):
    """Draw a single predicted box on the image and return the rendered view."""
    height, width, channels = img.shape
    pred_v = Visualizer(img[:, :, ::-1], model["metadata"], scale=1)
    instances = Instances(
        (height, width),
        pred_boxes=Boxes(torch.tensor(box).unsqueeze(0)),
        pred_classes=torch.tensor([pred_class]),
    )
    pred_v = pred_v.draw_instance_predictions(instances)
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    return pred
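
# The case-based explanations compare a lesion's 256-d features against the
# stored feature database. As an illustrative sketch of that comparison (not
# the implementation used by the plots.* helpers), here is a hypothetical
# nearest-neighbour lookup by cosine distance; it assumes the JSON database
# maps case identifiers to 256-d feature lists.
def nearest_cases(point, database, k=5):
    """Return the k database entries closest to `point` by cosine distance."""
    distances = [
        (name, spatial.distance.cosine(point, feats))
        for name, feats in database.items()
    ]
    return sorted(distances, key=lambda pair: pair[1])[:k]
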
def explain(img, model):
    state.write("Loading features...")
    database = json.load(open(FEATURES_DATABASE))  # case database of stored lesion features
    state.write("Computing logits...")
    instances, input_img = forward_model_full(model["model"], model["cfg"], img)
    instances.remove("pred_masks")
    pred_v = Visualizer(cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB), model["metadata"], scale=1)
    pred_v = pred_v.draw_instance_predictions(instances.to("cpu"))
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
    # st.tabs is only available from Streamlit 1.11.0 onwards; fall back to
    # stacked containers on older versions.
    if version.parse(st.__version__) >= version.parse("1.11.0"):
        tabs = st.tabs(["Result", "Detection"] + [f"Lesion #{i}" for i in range(len(instances))])
        lesion_tabs = tabs[2:]
        detection_tab = tabs[1]
        with tabs[0]:
            st.header("Image processed")
            st.success(
                "Use the tabs on the right to see the detected lesions "
                "and detailed explanations for each lesion"
            )
    else:
        tabs = [st.container() for _ in range(len(instances) + 1)]
        lesion_tabs = tabs[1:]
        detection_tab = tabs[0]
    state.write("Populating first tab...")
    with detection_tab:
        st.header("Detected lesions")
        st.image(pred)
    for i, (tab, box, pred_class, scores, features) in enumerate(
        zip(lesion_tabs, instances.pred_boxes, instances.pred_classes,
            instances.probs, instances.features)
    ):
        state.write(f"Populating tab for lesion #{i}...")
        # The last entry of the softmax distribution is the background
        # ("healthy") class; the remaining entries are the lesion classes.
        healthy_prob = scores[-1].item()
        scores = scores[:-1]
        features = features.tolist()
        with tab:
            st.header(f"Lesion #{i}")
            state.write(f"Populating classes for lesion #{i}...")
            lesion_img = draw_box(img, box.cpu(), pred_class, model)
            lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
            classes = ["healthy", "neoplastic", "aphthous", "traumatic"]
            y_pos = np.arange(len(classes))
            probs = [healthy_prob] + scores.cpu().numpy().tolist()
            probs_fig = plt.figure()
            plt.bar(y_pos, probs, align="center")
            plt.xticks(y_pos, classes)
            plt.ylabel("Probability")
            plt.title("Class")
            st.subheader("Classification")
            col1, col2 = st.columns(2)
            col1.image(lesion_img)
            col2.pyplot(probs_fig)
            st.subheader("Feature space")
            col1, col2 = st.columns(2)
            state.write(f"Populating PCA for lesion #{i}...")
            fig = plot_pca_point(
                point=features, features_database=FEATURES_DATABASE,
                pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100,
            )
            col1.pyplot(fig)
            state.write(f"Populating histogram for lesion #{i}...")
            fig = plot_histogram_dist(
                point=features, features_database=FEATURES_DATABASE,
                fig_h=800, fig_w=600, fig_dpi=100,
            )
            col2.pyplot(fig)
            state.write(f"Populating Gradcam++ for lesion #{i}...")
            st.subheader("Gradcam++")
            fig = plot_gradcam(
                model=MODEL, img=img, instance=i, fig_h=1600, fig_w=1200,
                fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3",
            )
            st.pyplot(fig)
    state.write("All done...")
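
# `extract_features` above mirrors the feature extraction done inside
# `forward_model_full` for known boxes. As a hypothetical sketch (the actual
# database-building pipeline is not part of this script), a database like
# features.json could be produced offline along these lines, assuming
# `samples` yields (name, CHW image tensor, Boxes) triples:
def build_features_database(model, samples, out_path):
    database = {
        name: extract_features(model, image, boxes)[0].tolist()
        for name, image, boxes in samples
    }
    with open(out_path, "w") as f:
        json.dump(database, f)
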
FILE = "./test.jpg"
MODEL = "./models/model.pth"
PCA_MODEL = "./models/pca.pkl"
FEATURES_DATABASE = "./assets/features/features.json"

st.header("Explainable Oral Lesion Detection")
st.markdown(
    """Demo for the paper [Explainable diagnosis of oral cancer via deep learning and case-based reasoning](https://mlpi.ing.unipi.it/doctoralai/)

Upload an image using the form below and click on "Process"
"""
)

FILE = st.file_uploader("Image", type=["jpg", "jpeg", "png"])
TH = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.5)
process = st.button("Process")
state = st.empty()

if process and FILE is not None:
    state.write("Loading model...")
    model = load_model()
    # Decode the uploaded image from its in-memory buffer
    # (np.fromstring is deprecated; np.frombuffer is the supported replacement).
    nparr = np.frombuffer(FILE.getvalue(), np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # img = cv2.imread(FILE)
    img = cv2.resize(img, (800, 800))
    explain(img, model)
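
# To try the demo locally, assuming this file is saved as app.py and the model
# weights, PCA model, and feature database exist at the paths configured above:
#
#   streamlit run app.py
#
# Streamlit re-runs the script on every interaction, so the model is only
# loaded when the "Process" button is pressed.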