# utils/image_utils.py
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import nibabel as nib
import base64
from io import BytesIO
default_orientation_type = 'transpose'
default_plt_origin_type = 'upper'
def image_to_base64(img, width=200):
buffered = BytesIO()
if img.mode == "RGBA":
img.save(buffered, format="PNG")
result = base64.b64encode(buffered.getvalue()).decode()
st.markdown(
f"""
""",
unsafe_allow_html=True
)
else:
img.save(buffered, format="JPEG")
result = base64.b64encode(buffered.getvalue()).decode()
st.markdown(
f"""
""",
unsafe_allow_html=True
)
def processing_slice_to_right_orientation(img_slice, type=default_orientation_type):
if type == 'transpose':
return img_slice.T
elif type == 'rot90':
return np.rot90(img_slice)
elif type == 'none':
return img_slice
def restore_slice_to_wrong_orientation(img_slice, type=default_orientation_type):
if type == 'transpose':
return img_slice.T
elif type == 'rot90':
return np.rot90(img_slice,3)
elif type == 'none':
return img_slice
def load_image_canonical(nii_file):
img = nib.load(nii_file)
#img_canonical = nib.as_closest_canonical(img)
data = img.get_fdata()
return img
def get_compatible_cmap(name="tab20", N=20):
# 优先使用 plt.get_cmap(),如果没有 fallback 到 cm.get_cmap()
try:
return plt.get_cmap(name, N)
except TypeError:
# for older versions of matplotlib
return cm.get_cmap(name, N)
def generate_color_map(label_ids, cmap='tab20'):
cmap = get_compatible_cmap(cmap, len(label_ids)) # or 'Set3', 'tab10'
color_map = {}
for i, label_id in enumerate(label_ids):
rgba = cmap(i)
rgb = tuple(int(255 * c) for c in rgba[:3])
color_map[label_id] = ",".join(map(str, rgb))
return color_map
# utils/image_utils.py
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
def global_slice_slider(volume_shape):
st.markdown("### 🔎 Global slice controller")
col_z, col_y, col_x = st.columns(3)
with col_z:
z_idx = st.slider("Axial (Z)", 0, volume_shape[2]-1, volume_shape[2] // 2, key="z_slider")
with col_y:
y_idx = st.slider("Coronal (Y)", 0, volume_shape[1]-1, volume_shape[1] // 2, key="y_slider")
with col_x:
x_idx = st.slider("Sagittal (X)", 0, volume_shape[0]-1, volume_shape[0] // 2, key="x_slider")
return z_idx, y_idx, x_idx
from PIL import Image
def show_single_slice_label(label2d, label_colors, title="Label Slice"):
"""
显示单张 2D 标签图像(使用 RGB 映射)。
label_colors: dict[int -> str],如 {1: "255,0,0"}
"""
import matplotlib.pyplot as plt
import io
rgb_map = np.zeros((*label2d.shape, 3), dtype=np.uint8)
for label, rgb_str in label_colors.items():
rgb_vals = [int(v) for v in rgb_str.split(",")]
mask = label2d == label
mask = processing_slice_to_right_orientation(mask)
for c in range(3):
rgb_map[:, :, c][mask] = rgb_vals[c]
st.image(Image.fromarray(rgb_map), use_container_width =True)
def show_single_slice_image(image2d, title="Slice", orientation_type=default_orientation_type):
"""
用 Streamlit 原生方式显示灰度图(不经过 matplotlib)。
"""
import numpy as np
# normalize to [0, 255]
img = image2d.astype(np.float32)
img = np.nan_to_num(img)
img_min, img_max = img.min(), img.max()
if img_max > img_min:
img = (img - img_min) / (img_max - img_min)
img = processing_slice_to_right_orientation(img, orientation_type)
img_uint8 = (img * 255).astype(np.uint8)
'''fig, ax = plt.subplots()
ax.imshow(rgb_map, cmap='gray', origin=default_plt_origin_type)
ax.axis('off')
st.pyplot(fig)'''
st.image(img_uint8, caption=title, use_container_width =True, clamp=True)
def show_single_planes_interactive(image, z_idx, y_idx, x_idx, orientation_type=default_orientation_type):
"""
Show three orthogonal planes simultaneously with slice sliders.
Supports nibabel image object or raw NumPy array.
"""
if hasattr(image, "get_fdata"):
data = image.get_fdata()
else:
data = image
data = np.nan_to_num(data).astype(np.float32)
fig, ax = plt.subplots()
ax.imshow(processing_slice_to_right_orientation(data[:, :, z_idx],orientation_type), cmap='gray', origin=default_plt_origin_type)
ax.axis('off')
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches='tight', pad_inches=0)
buf.seek(0)
st.image(buf)
#st.pyplot(fig)
def show_three_planes_interactive(image, z_idx, y_idx, x_idx, orientation_type=default_orientation_type):
"""
Show three orthogonal planes simultaneously with slice sliders.
Supports nibabel image object or raw NumPy array.
"""
if hasattr(image, "get_fdata"):
data = image.get_fdata()
else:
data = image
data = np.nan_to_num(data).astype(np.float32)
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(processing_slice_to_right_orientation(data[:, :, z_idx],orientation_type), cmap='gray', origin=default_plt_origin_type)
axs[0].set_title(f"Axial @ {z_idx}")
axs[1].imshow(processing_slice_to_right_orientation(data[:, y_idx, :],orientation_type), cmap='gray', origin=default_plt_origin_type)
axs[1].set_title(f"Coronal @ {y_idx}")
axs[2].imshow(processing_slice_to_right_orientation(data[x_idx, :, :],orientation_type), cmap='gray', origin=default_plt_origin_type)
axs[2].set_title(f"Sagittal @ {x_idx}")
for ax in axs:
ax.axis('off')
#st.pyplot(fig)
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches='tight', pad_inches=0)
buf.seek(0)
st.image(buf)
def show_label_overlay_single(label_volume, z_idx, y_idx, x_idx, label_colors=None):
"""
Show label slices in three orthogonal planes with sliders and color overlays.
label_colors: dict[int -> str] with RGB strings like "255,0,0"
"""
if hasattr(label_volume, "get_fdata"):
label_data = label_volume.get_fdata().astype(np.int32)
else:
label_data = label_volume.astype(np.int32)
z_max = label_data.shape[2] - 1
y_max = label_data.shape[1] - 1
x_max = label_data.shape[0] - 1
def label_to_rgb(slice_data):
if not label_colors:
return slice_data
rgb_map = np.zeros((*slice_data.shape, 3), dtype=np.uint8)
for label, rgb_str in label_colors.items():
if isinstance(rgb_str, str):
rgb_vals = [int(v) for v in rgb_str.split(",")]
else:
rgb_vals = [0, 0, 0]
mask = slice_data == label
for c in range(3):
rgb_map[:, :, c][mask] = rgb_vals[c]
return rgb_map
axial = label_to_rgb(processing_slice_to_right_orientation(label_data[:, :, z_idx]))
coronal = label_to_rgb(processing_slice_to_right_orientation(label_data[:, y_idx, :]))
sagittal = label_to_rgb(processing_slice_to_right_orientation(label_data[x_idx, :, :]))
fig, ax = plt.subplots()
ax.imshow(axial, origin=default_plt_origin_type)
ax.axis('off')
# st.pyplot(fig)
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches='tight', pad_inches=0)
buf.seek(0)
st.image(buf)
def show_label_overlay(label_volume, z_idx, y_idx, x_idx, label_colors=None):
"""
Show label slices in three orthogonal planes with sliders and color overlays.
label_colors: dict[int -> str] with RGB strings like "255,0,0"
"""
if hasattr(label_volume, "get_fdata"):
label_data = label_volume.get_fdata().astype(np.int32)
else:
label_data = label_volume.astype(np.int32)
z_max = label_data.shape[2] - 1
y_max = label_data.shape[1] - 1
x_max = label_data.shape[0] - 1
def label_to_rgb(slice_data):
if not label_colors:
return slice_data
rgb_map = np.zeros((*slice_data.shape, 3), dtype=np.uint8)
for label, rgb_str in label_colors.items():
if isinstance(rgb_str, str):
rgb_vals = [int(v) for v in rgb_str.split(",")]
else:
rgb_vals = [0, 0, 0]
mask = slice_data == label
for c in range(3):
rgb_map[:, :, c][mask] = rgb_vals[c]
return rgb_map
axial = label_to_rgb(processing_slice_to_right_orientation(label_data[:, :, z_idx]))
coronal = label_to_rgb(processing_slice_to_right_orientation(label_data[:, y_idx, :]))
sagittal = label_to_rgb(processing_slice_to_right_orientation(label_data[x_idx, :, :]))
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(axial, origin=default_plt_origin_type)
axs[0].set_title(f"Axial @ {z_idx}")
axs[1].imshow(coronal, origin=default_plt_origin_type)
axs[1].set_title(f"Coronal @ {y_idx}")
axs[2].imshow(sagittal, origin=default_plt_origin_type)
axs[2].set_title(f"Sagittal @ {x_idx}")
for ax in axs:
ax.axis('off')
#st.pyplot(fig)
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches='tight', pad_inches=0)
buf.seek(0)
st.image(buf)
def show_three_planes(image, title_prefix=""):
if hasattr(image, "get_fdata"):
data = image.get_fdata()
else:
data = image
data = np.nan_to_num(data).astype(np.float32)
mid_axial = data.shape[2] // 2
mid_coronal = data.shape[1] // 2
mid_sagittal = data.shape[0] // 2
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(processing_slice_to_right_orientation(data[:, :, mid_axial]), cmap='gray', origin=default_plt_origin_type)
axs[0].set_title(f'{title_prefix} Axial')
axs[1].imshow(processing_slice_to_right_orientation(data[:, mid_coronal, :]), cmap='gray', origin=default_plt_origin_type)
axs[1].set_title(f'{title_prefix} Coronal')
axs[2].imshow(processing_slice_to_right_orientation(data[mid_sagittal, :, :]), cmap='gray', origin=default_plt_origin_type)
axs[2].set_title(f'{title_prefix} Sagittal')
for ax in axs:
ax.axis('off')
st.pyplot(fig)