# SegVol / model / utils / visualize.py
# Hugging Face file-viewer header, kept as comments so the module parses:
#   BoyaWu10's picture | init the space (#2) | a950ee6
#   raw | history blame contribute delete | No virus | 3.22 kB
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import os
import torch
import monai.transforms as transforms
def draw_result(category, image, bboxes, points, logits, gt3D):
    """Save one three-panel PNG per slice: image, ground truth, prediction.

    The image, label and logits volumes are given a channel axis, resized to
    ``(32, 256, 256)`` in ``nearest-exact`` mode, and the logits are turned
    into a binary mask (``sigmoid(logits) > 0.5``). For every index ``j``
    along the first axis of the resized image a figure is written to
    ``./fig_examples/{category}/{category}_{j}.png``.

    Args:
        category: Used for both the output sub-directory and the file prefix.
        image: Indexed as ``image[0, 0]`` before resizing; presumably a
            ``(batch, channel, D, H, W)`` tensor -- TODO confirm with caller.
        bboxes: ``None``, or an indexable whose first entry is a tensor of six
            values unpacked as ``(x1, y1, z1, x2, y2, z2)``. The box is drawn
            on slice ``j`` when ``x1 <= j <= x2``, as a rectangle from
            ``(z1, y1)`` to ``(z2, y2)``.
        points: ``None``, or a pair ``(coords, labels)`` of tensors;
            presumably ``coords`` is ``[n, 3]`` and ``labels`` is ``[n]``.
            Point ``i`` is drawn on slice ``j`` when ``coords[i][0] == j``.
        logits: Pre-sigmoid predictions; presumably a ``(D, H, W)`` tensor.
        gt3D: Ground-truth mask; presumably the same shape as ``logits``.

    NOTE(review): box and point coordinates are used as given and are not
    rescaled with the volumes, so they presumably already refer to the
    ``(32, 256, 256)`` grid -- confirm against the caller.
    """
    zoom_out_transform = transforms.Compose([
        # NOTE(review): AddChanneld is deprecated in recent MONAI releases --
        # confirm the pinned MONAI version still provides it.
        transforms.AddChanneld(keys=["image", "label", "logits"]),
        transforms.Resized(
            keys=["image", "label", "logits"],
            spatial_size=(32, 256, 256),
            mode='nearest-exact',
        ),
    ])
    print(image.shape, gt3D.shape, logits.shape)
    image = image[0, 0]
    post_item = zoom_out_transform({
        'image': image,
        'label': gt3D,
        'logits': logits,
    })
    # Index 0 drops the channel axis that the transform added.
    image, gt3D, logits = post_item['image'][0], post_item['label'][0], post_item['logits'][0]
    preds = (torch.sigmoid(logits) > 0.5).int()

    root_dir = f'./fig_examples/{category}/'
    # exist_ok avoids the race between checking for the directory and
    # creating it (e.g. when several processes render at once).
    os.makedirs(root_dir, exist_ok=True)

    if bboxes is not None:
        x1, y1, z1, x2, y2, z2 = bboxes[0].cpu().numpy()
    if points is not None:
        points_ax = points[0].cpu().numpy()     # presumably [n, 3]
        points_label = points[1].cpu().numpy()  # presumably [n]

    # Copy each volume to host memory once rather than once per slice.
    image_np = image.detach().cpu().numpy()
    preds_np = preds.detach().cpu().numpy()
    label_np = gt3D.detach().cpu().numpy()

    for j in range(image_np.shape[0]):
        img_2d = image_np[j]
        preds_2d = preds_np[j]
        label_2d = label_np[j]

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        try:
            # Original image; the prompts (box / points) are drawn on this panel.
            ax1.imshow(img_2d, cmap='gray')
            ax1.set_title('Image with prompt')
            ax1.axis('off')

            # Ground-truth overlay.
            ax2.imshow(img_2d, cmap='gray')
            show_mask(label_2d, ax2)
            ax2.set_title('Ground truth')
            ax2.axis('off')

            # Prediction overlay.
            ax3.imshow(img_2d, cmap='gray')
            show_mask(preds_2d, ax3)
            ax3.set_title('Prediction')
            ax3.axis('off')

            # Box prompt: only on slices inside its extent along the first axis.
            if bboxes is not None and x1 <= j <= x2:
                show_box((z1, y1, z2, y2), ax1)

            # Point prompts: only those whose first coordinate is this slice.
            if points is not None:
                for point, label in zip(points_ax, points_label):
                    if j == point[0]:
                        show_points(point, label, ax1)

            fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
            fig.savefig(os.path.join(root_dir, f'{category}_{j}.png'), bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or saving failed, so
            # figures do not accumulate across slices.
            plt.close(fig)
def show_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as a translucent yellow RGBA overlay.

    The last two dimensions of ``mask`` are taken as height and width. Each
    mask value scales a fixed RGBA colour, so zero entries become fully
    transparent pixels; the result is shown with ``alpha=0.35``.
    """
    rgba = np.array([251/255, 252/255, 30/255, 0.6])
    rows, cols = mask.shape[-2], mask.shape[-1]
    # Broadcast (rows, cols, 1) against (1, 1, 4) to get an RGBA image.
    overlay = mask.reshape(rows, cols, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay, alpha=0.35)
def show_box(box, ax):
    """Add an unfilled blue rectangle to ``ax``.

    ``box`` is indexed as ``(x_start, y_start, x_end, y_end)``; the rectangle
    is anchored at the start corner with width and height given by the
    differences.
    """
    x_start, y_start = box[0], box[1]
    x_end, y_end = box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
def show_points(points_ax, points_label, ax):
    """Scatter one prompt point on ``ax``.

    The marker is placed at ``(points_ax[2], points_ax[1])`` and is red when
    ``points_label == 0``, blue otherwise (presumably 0 marks a negative
    prompt -- confirm with caller).
    """
    print('draw point')
    if points_label == 0:
        marker_color = 'red'
    else:
        marker_color = 'blue'
    x_coord, y_coord = points_ax[2], points_ax[1]
    ax.scatter(x_coord, y_coord, c=marker_color, marker='o', s=200)