# Federico Galatolo - commit e0a47a2: "install detectron2 from script"
# (raw | history | blame | 7.58 kB: file-viewer header captured together with the source)
import os
import streamlit as st
try:
    import detectron2  # noqa: F401 -- availability probe; submodules are imported below
except ImportError:
    # Install detectron2 on first run from its find-links index (CPU wheels
    # built for torch 1.10). Run pip through the current interpreter so the
    # package lands in this environment, and fail loudly if pip fails instead
    # of crashing later on a confusing ImportError.
    import importlib
    import subprocess
    import sys

    with st.spinner("Installing detectron2"):
        subprocess.check_call(
            [
                sys.executable, "-m", "pip", "install", "detectron2",
                "-f", "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html",
            ]
        )
    # The import system caches directory listings; make sure the freshly
    # installed package is visible to the detectron2 imports that follow.
    importlib.invalidate_caches()
import cv2
import sys
import argparse
import numpy as np
import json
import torch
import torch.nn.functional as F
import detectron2.data.transforms as T
import torchvision
from collections import OrderedDict
from scipy import spatial
import matplotlib.pyplot as plt
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures.boxes import Boxes
from detectron2.structures import Instances
from plots.plot_pca_point import plot_pca_point
from plots.plot_histogram_dist import plot_histogram_dist
from plots.plot_gradcam import plot_gradcam
def extract_features(model, img, box):
    """Compute one pooled backbone descriptor per box of an image.

    Args:
        model: detectron2-style model exposing ``preprocess_image``,
            ``backbone`` and ``roi_heads`` (``box_in_features``, ``box_pooler``).
        img: image tensor indexed as ``(C, H, W)``; height and width are read
            from ``img.shape[1:3]``. It is passed to ``preprocess_image``
            as-is (no resizing happens here).
        box: boxes to describe, passed unchanged to ``roi_heads.box_pooler``
            (presumably a detectron2 ``Boxes`` in ``img`` pixel coordinates --
            TODO confirm against callers).

    Returns:
        Tensor of shape ``(num_boxes, C)``, where ``C`` is the channel count
        of the pooled feature maps (256 for the FPN config this app loads,
        matching the previously hard-coded size).
    """
    height, width = img.shape[1:3]
    inputs = [{"image": img, "height": height, "width": width}]
    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features_, [box])
        # Average over the whole pooled grid: one C-dim row per box. For a
        # 7x7 pooler this equals avg_pool2d(7x7), but it stays correct for
        # any pooler resolution or channel count (a fixed view(-1, 256)
        # would silently split or merge rows).
        return box_features.mean(dim=(2, 3))
def forward_model_full(model, cfg, cv_img):
    """Run detection step by step, attaching class probabilities and features.

    The model's stages (preprocess, backbone, proposals, box head, predictor,
    mask head, postprocess) are called one by one so that intermediate values
    can be attached to each kept detection:

    * ``probs``: the detection's row of ``box_predictor.predict_probs``. The
      last column is what ``explain`` displays as "healthy".
    * ``features``: the detection's ROI-pooled backbone features, averaged
      over the pooled spatial grid.

    Args:
        model: detectron2 model (``load_model()["model"]``).
        cfg: the config used to build ``model``. Only ``INPUT.MIN_SIZE_TEST``
            and ``INPUT.MAX_SIZE_TEST`` are read here.
        cv_img: ``H x W x C`` image array as read by OpenCV (assumes the
            channel order matches ``cfg.INPUT.FORMAT`` -- TODO confirm).

    Returns:
        ``(instances, cv_img)``. ``instances`` holds the predictions rescaled
        to the original image size, with the extra ``probs`` and ``features``
        fields. ``cv_img`` is the unmodified input.
    """
    height, width = cv_img.shape[:2]
    # Test-time shortest-edge resize driven by the config's test sizes.
    transform_gen = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    image = transform_gen.get_transform(cv_img).apply_image(cv_img)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    # height/width are the ORIGINAL sizes; _postprocess rescales the outputs to them.
    inputs = [{"image": image, "height": height, "width": width}]
    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        proposals, _ = model.proposal_generator(images, features, None)
        # ROI-pool every proposal on the feature maps the box head consumes.
        features_ = [features[f] for f in model.roi_heads.box_in_features]
        box_features = model.roi_heads.box_pooler(features_, [x.proposal_boxes for x in proposals])
        box_head = model.roi_heads.box_head(box_features)
        predictions = model.roi_heads.box_predictor(box_head)
        # One descriptor per proposal: mean over the pooled grid. This matches
        # the former avg_pool2d(7x7).view(-1, 256) for a 7x7, 256-channel
        # pooler, but keeps one row per proposal for any pooler shape.
        output_features = box_features.mean(dim=(2, 3))
        probs = model.roi_heads.box_predictor.predict_probs(predictions, proposals)
        # pred_inds: per-image indices of the proposals kept after scoring/NMS.
        pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
        pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)
        pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)
    # The batch holds a single image, so take element 0 of every per-image
    # list. Proposal indices of image 0 also index output_features directly.
    keep = pred_inds[0]
    instances = pred_instances[0]["instances"]
    instances.set("probs", probs[0][keep])
    instances.set("features", output_features[keep])
    return instances, cv_img
def load_model():
    """Build the lesion detector and the metadata used to visualize it.

    Starts from the model-zoo ``mask_rcnn_R_50_FPN_3x`` config and sets
    3 ROI-head classes. It loads the weights at the module constant ``MODEL``
    and uses the module constant ``TH`` as the test-time score threshold.

    Returns:
        dict with keys ``predictor`` (``DefaultPredictor``), ``model``
        (``predictor.model``), ``metadata`` (class names for ``Visualizer``)
        and ``cfg`` (the config the model was built from).
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.WEIGHTS = MODEL
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
    # detectron2's default MODEL.DEVICE is "cuda"; the app installs the
    # CPU-only detectron2 wheel, so fall back to CPU when no GPU is available.
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    metadata = Metadata()
    metadata.set(
        evaluator_type="coco",
        # Index order must match the model's class indices.
        thing_classes=["neoplastic", "aphthous", "traumatic"],
        # NOTE(review): dataset ids are string keys here -- confirm consumers
        # do not expect ints.
        thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
    )
    predictor = DefaultPredictor(cfg)
    model = predictor.model
    return dict(
        predictor=predictor,
        model=model,
        metadata=metadata,
        cfg=cfg
    )
def compute_similarities(features, database, distance=None):
    """Rank database files by the distance of their nearest entry to ``features``.

    Args:
        features: query feature vector.
        database: mapping ``file_name -> list of entries``. Each entry is a
            dict with ``"features"``, ``"roi"`` and ``"type"`` keys.
        distance: name of a ``scipy.spatial.distance`` function (e.g.
            ``"cosine"``, ``"euclidean"``). Defaults to the module constant
            ``DISTANCE``.

    Returns:
        ``OrderedDict`` mapping each file name to
        ``{"dist", "file_name", "box", "type"}`` for that file's closest
        entry, ordered by increasing distance. Files with no entries are
        omitted.
    """
    dist_fn = getattr(spatial.distance, DISTANCE if distance is None else distance)
    nearest = {}
    for file_name, elems in database.items():
        for elem in elems:
            dist = dist_fn(elem["features"], features)
            # The result is keyed by file name, so keep only the closest
            # entry of each file rather than whichever came last.
            if file_name not in nearest or dist < nearest[file_name]["dist"]:
                nearest[file_name] = dict(
                    dist=dist,
                    file_name=file_name,
                    box=elem["roi"],
                    type=elem["type"]
                )
    return OrderedDict(sorted(nearest.items(), key=lambda e: e[1]["dist"]))
def draw_box(file_name, box, type, model, resize_input=False):
    """Draw one labelled box on an image and return it at 800x800.

    Args:
        file_name: the image to draw on, either a BGR ``H x W x 3`` array (as
            ``explain`` passes) or a path read with ``cv2.imread``. The
            parameter keeps its historical name for compatibility.
        box: ``[x0, y0, x1, y1]`` in the image's pixel coordinates (list,
            array or tensor), in the format of detectron2 ``Boxes``.
        type: class index into ``model["metadata"].thing_classes``.
        model: dict returned by ``load_model``. Only ``"metadata"`` is used.
        resize_input: currently unused; kept for interface compatibility.

    Returns:
        800x800 BGR image with the box and its class label drawn.

    Raises:
        FileNotFoundError: if a path is given and the image cannot be read.
    """
    # Draw on the image the caller passed, not on a module-level global.
    if isinstance(file_name, (str, os.PathLike)):
        img = cv2.imread(os.fspath(file_name))
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f"Cannot read image: {file_name}")
    else:
        img = file_name
    height, width = img.shape[:2]
    # Visualizer takes RGB; the input is BGR.
    pred_v = Visualizer(img[:, :, ::-1], model["metadata"], scale=1)
    instances = Instances(
        (height, width),
        pred_boxes=Boxes(torch.as_tensor(box, dtype=torch.float32).reshape(1, 4).cpu()),
        pred_classes=torch.as_tensor(type, dtype=torch.int64).reshape(1).cpu(),
    )
    pred_v = pred_v.draw_instance_predictions(instances)
    # Back to BGR for the OpenCV-based callers.
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    return pred
def _class_probability_figure(healthy_prob, class_scores, class_names):
    """Return a bar chart of one lesion's probabilities, with "healthy" first."""
    classes = ["healthy"] + list(class_names)
    y_pos = np.arange(len(classes))
    probs = [healthy_prob] + class_scores.tolist()
    fig = plt.figure()
    plt.bar(y_pos, probs, align="center")
    plt.xticks(y_pos, classes)
    plt.ylabel("Probability")
    plt.title("Class")
    return fig


def _show_figure(container, fig):
    """Render a matplotlib figure in a Streamlit container, then free it."""
    container.pyplot(fig)
    # pyplot keeps figures alive until they are closed, and Streamlit re-runs
    # the script on every interaction, so unclosed figures accumulate.
    plt.close(fig)


def _render_lesion_tab(i, img, model, box, cls_id, scores, features):
    """Write the classification, feature-space and Grad-CAM++ views of lesion ``i``."""
    # The last probability column is displayed as "healthy"; the remaining
    # columns follow the metadata's thing_classes order.
    healthy_prob = scores[-1].item()
    features = features.tolist()

    st.header(f"Lesion #{i}")
    lesion_img = draw_box(img, box, cls_id, model)
    lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
    probs_fig = _class_probability_figure(healthy_prob, scores[:-1], model["metadata"].thing_classes)

    st.subheader("Classification")
    col1, col2 = st.columns(2)
    col1.image(lesion_img)
    _show_figure(col2, probs_fig)

    st.subheader("Feature space")
    col1, col2 = st.columns(2)
    fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
    _show_figure(col1, fig)
    fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
    _show_figure(col2, fig)

    st.subheader("Gradcam++")
    # NOTE(review): plot_gradcam receives the FILE / MODEL paths rather than
    # `img` / `model`; presumably it reloads both, so this view matches the
    # rest of the tab only when `img` was read from FILE. TODO confirm.
    fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
    _show_figure(st, fig)


def explain(img, model):
    """Run the detector on ``img`` and render the Streamlit explanation UI.

    Creates a "Detection" tab with every detection drawn, plus one tab per
    detected lesion. Each lesion tab shows the class probabilities,
    feature-space plots and a Grad-CAM++ view. It also updates the
    module-level ``state`` and ``tooltip`` placeholders.

    Args:
        img: BGR image array (as read by ``cv2.imread``).
        model: dict returned by ``load_model``.
    """
    instances, input_img = forward_model_full(model["model"], model["cfg"], img)
    # Everything below is display-only, so move the results to CPU once.
    instances = instances.to("cpu")
    if instances.has("pred_masks"):
        instances.remove("pred_masks")  # draw boxes only, not masks
    pred_v = Visualizer(cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB), model["metadata"], scale=1)
    pred_v = pred_v.draw_instance_predictions(instances)
    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)

    tabs = st.tabs(["Detection"] + [f"Lesion #{i}" for i in range(len(instances))])
    with tabs[0]:
        st.header("Detected lesions")
        state.text("All done...")
        tooltip.success("Use the tabs for a detailed explanation of each lesion")
        st.image(pred)

    lesions = zip(tabs[1:], instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)
    for i, (tab, box, cls_id, scores, features) in enumerate(lesions):
        with tab:
            _render_lesion_tab(i, img, model, box, cls_id, scores, features)
# --- Configuration (read as module globals by the functions above) ---
FILE = "./test.jpg"  # image analysed by the app (also passed to plot_gradcam)
MODEL = "./models/model.pth"  # detector weights loaded by load_model
PCA_MODEL = "./models/pca.pkl"  # PCA model passed to plot_pca_point
FEATURES_DATABASE = "./assets/features/features.json"  # feature database for the plots
DISTANCE = "cosine"  # scipy.spatial.distance function name used by compute_similarities
TH = 0.5  # detection score threshold (SCORE_THRESH_TEST)

# --- Streamlit page ---
state = st.empty()  # status line, updated by explain()
tooltip = st.empty()  # hint box, filled by explain()
img = cv2.imread(FILE)
if img is None:
    # cv2.imread signals failure by returning None instead of raising;
    # stop here with a clear message rather than failing inside cv2.resize.
    st.error(f"Cannot read input image: {FILE}")
    st.stop()
img = cv2.resize(img, (800, 800))
state.write("Loading model...")
model = load_model()
explain(img, model)