# Source path: aaa_web_app/Comp2Comp-main/comp2comp/spine/spine_visualization.py
# Hosting-page residue (not part of the module): "AdritRao's picture",
# "Upload 62 files", commit a3290d1
"""
@author: louisblankemeier
"""
import os
from pathlib import Path
from typing import Union
import numpy as np
from comp2comp.visualization.detectron_visualizer import Visualizer
def spine_binary_segmentation_overlay(
    img_in: np.ndarray,
    mask: np.ndarray,
    base_path: Union[str, Path],
    file_name: str,
    figure_text_key=None,
    spine_hus=None,
    seg_hus=None,
    spine=True,
    model_type=None,
    pixel_spacing=None,
):
    """Render segmentation and ROI overlays on an image and save the figure.

    The image is clipped to [-300, 1800], min-max normalised, scaled to
    0-255 and replicated to three channels. The per-level segmentation
    masks and ROI masks are blended on top, a small table with one row per
    level ("ROI" and "Seg" columns) is drawn near the top-right corner, and
    the result is saved to ``<base_path>/images/<file_name>``.

    Args:
        img_in (np.ndarray): Image array; it is reshaped to
            ``(shape[0], shape[1], 1)``, so it must be effectively 2-D.
            Presumably CT intensities in HU -- confirm with the caller.
        mask (np.ndarray): 3-D array indexed as ``mask[:, :, k]``. With
            ``n = len(spine_hus)``, channels ``0..n-1`` are the segmentation
            masks and channels ``n..2n-1`` are the ROI masks, both ordered
            like the *reversed* key order of ``spine_hus``.
        base_path (Union[str, Path]): Output directory; an ``images``
            sub-directory is created inside it (parents included).
        file_name (str): Output file name inside ``images``.
        figure_text_key (optional): Not used by this function.
        spine_hus (dict): Maps level name to the value printed in the "ROI"
            column. Keys must be among "T12", "L1", "L2", "L3", "L4", "L5".
            Required in practice despite the ``None`` default.
        seg_hus (dict): Maps level name to the value printed in the "Seg"
            column. Must contain every key of ``spine_hus``.
        spine (bool, optional): Not used by this function.
        model_type (optional): Not used by this function.
        pixel_spacing (optional): Not used by this function.

    Returns:
        Whatever ``save`` of the visualizer output object returns.

    Raises:
        AttributeError: If ``spine_hus`` is left as ``None``.
        KeyError: If a level name is not one of the supported labels or is
            missing from ``seg_hus``.
    """
    # One RGB colour per vertebral level; rows are selected through label_map.
    _COLORS = np.array(
        [
            [1.000, 0.000, 0.000],
            [0.000, 1.000, 0.000],
            [1.000, 1.000, 0.000],
            [1.000, 0.500, 0.000],
            [0.000, 1.000, 1.000],
            [1.000, 0.000, 1.000],
        ],
        dtype=np.float32,
    )
    label_map = {"L5": 0, "L4": 1, "L3": 2, "L2": 3, "L1": 4, "T12": 5}
    _ROI_COLOR = np.array([1.000, 0.340, 0.200])
    _SPINE_TEXT_OFFSET_FROM_TOP = 10.0
    _SPINE_TEXT_OFFSET_FROM_RIGHT = 40.0
    _SPINE_TEXT_VERTICAL_SPACING = 14.0
    # Shared by the segmentation and ROI overlays. Previously this was only
    # assigned inside the first loop and silently reused by the second.
    _MASK_ALPHA = 0.2
    _FONT_SIZE = 9

    img_in = np.clip(img_in, -300, 1800)
    img_in = normalize_img(img_in) * 255.0

    # parents=True so a not-yet-existing base_path does not abort the export.
    images_base_path = Path(base_path) / "images"
    images_base_path.mkdir(parents=True, exist_ok=True)

    # Replicate the single grey channel three times to get an RGB canvas.
    img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1))
    img_rgb = np.tile(img_in, (1, 1, 3))
    vis = Visualizer(img_rgb)

    levels = list(spine_hus.keys())
    levels.reverse()
    num_levels = len(levels)

    # Segmentation masks: first num_levels channels, coloured per level.
    for i, level in enumerate(levels):
        vis.draw_binary_mask(
            mask[:, :, i].astype(int),
            color=_COLORS[label_map[level]],
            edge_color=None,
            alpha=_MASK_ALPHA,
            area_threshold=0,
        )

    # ROI masks: the following num_levels channels, one shared colour.
    for i in range(num_levels):
        vis.draw_binary_mask(
            mask[:, :, num_levels + i].astype(int),
            color=_ROI_COLOR,
            edge_color=_ROI_COLOR,
            alpha=_MASK_ALPHA,
            area_threshold=0,
        )

    # x positions of the three table columns, measured from the right edge.
    seg_col_x = mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT
    roi_col_x = seg_col_x - 35
    label_col_x = seg_col_x - 80

    # Header row.
    for header, col_x in (("ROI", roi_col_x), ("Seg", seg_col_x)):
        vis.draw_text(
            text=header,
            position=(col_x, _SPINE_TEXT_OFFSET_FROM_TOP),
            color=[1, 1, 1],
            font_size=_FONT_SIZE,
            horizontal_alignment="center",
        )

    # One row per level: "<level>:", rounded ROI value, rounded Seg value.
    for row, level in enumerate(levels, start=1):
        row_y = _SPINE_TEXT_VERTICAL_SPACING * row + _SPINE_TEXT_OFFSET_FROM_TOP
        level_color = _COLORS[label_map[level]]
        cells = (
            (f"{level}:", label_col_x, "left"),
            (f"{round(float(spine_hus[level]))}", roi_col_x, "center"),
            (f"{round(float(seg_hus[level]))}", seg_col_x, "center"),
        )
        for text, col_x, alignment in cells:
            vis.draw_text(
                text=text,
                position=(col_x, row_y),
                color=level_color,
                font_size=_FONT_SIZE,
                horizontal_alignment=alignment,
            )
        # NOTE: a dashed horizontal line per level used to be drawn here but
        # was disabled; it relied on `inferior_superior_centers`, which is not
        # defined in this function, and on `pixel_spacing`.

    vis_obj = vis.get_output()
    return vis_obj.save(os.path.join(images_base_path, file_name))
def normalize_img(img: np.ndarray) -> np.ndarray:
    """Min-max normalize an image to the range [0, 1].

    Args:
        img (np.ndarray): Input image (must be non-empty).

    Returns:
        np.ndarray: ``(img - img.min()) / (img.max() - img.min())``. When the
        image is constant (``max == min``) an all-zero array is returned
        instead of the NaNs a 0/0 division would produce.
    """
    lowest = img.min()
    value_range = img.max() - lowest
    # A constant image has zero range; dividing the (all-zero) shifted image
    # by 1 yields zeros rather than NaN.
    divisor = value_range if value_range != 0 else 1
    return (img - lowest) / divisor