Spaces:
Sleeping
Sleeping
File size: 2,448 Bytes
bc679dd 0a91a75 bc679dd fa81659 bc679dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import argparse
import torch
import matplotlib
import matplotlib.pyplot as plt
from types import SimpleNamespace
from detectron2.utils.visualizer import Visualizer
from detectron2.data import Metadata
from detectron2 import model_zoo
from plots.gradcam.detectron2_gradcam import Detectron2GradCAM
def plot_gradcam(**kwargs):
kwargs = SimpleNamespace(**kwargs)
config_file = model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg_list = [
"MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(kwargs.th),
"MODEL.ROI_HEADS.NUM_CLASSES", "3",
"MODEL.WEIGHTS", kwargs.model,
"MODEL.DEVICE", "cpu"
]
metadata = Metadata()
metadata.set(
evaluator_type="coco",
thing_classes=["neoplastic", "aphthous", "traumatic"],
thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
)
cam_extractor = Detectron2GradCAM(config_file, cfg_list)
image_dict, cam_orig = cam_extractor.get_cam(img=kwargs.img, target_instance=kwargs.instance, layer_name=kwargs.layer, grad_cam_type="GradCAM++")
with torch.no_grad():
fig = plt.figure(figsize=(kwargs.fig_h/kwargs.fig_dpi, kwargs.fig_w/kwargs.fig_dpi), dpi=kwargs.fig_dpi)
v = Visualizer(image_dict["image"], metadata, scale=1.0)
img = image_dict["output"]["instances"][kwargs.instance]
img.remove("pred_masks")
out = v.draw_instance_predictions(img.to("cpu"))
plt.gca().set_axis_off()
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
hspace = 0, wspace = 0)
plt.margins(0,0)
plt.imshow(out.get_image(), interpolation='none')
plt.imshow(image_dict["cam"], cmap='jet', alpha=0.5)
return fig
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--layer", type=str, default="backbone.bottom_up.res5.2.conv3")
parser.add_argument("--th", type=float, default=0.5)
parser.add_argument("--file", type=str, required=True)
parser.add_argument("--instance", type=int, required=True)
parser.add_argument("--output", type=str, default="")
parser.add_argument("--fig-h", type=int, default=1080)
parser.add_argument("--fig-w", type=int, default=720)
parser.add_argument("--fig-dpi", type=int, default=100)
args = parser.parse_args()
plot_gradcam(**vars(args)) |