import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image, ImageDraw | |
import os | |
import torch | |
import monai.transforms as transforms | |
def draw_result(category, image, bboxes, points, logits, gt3D,
                spatial_size=(32, 256, 256), work_dir='.'):
    """Save one three-panel PNG per slice: image + prompts, ground truth, prediction.

    The image, label and logits volumes are resized to ``spatial_size`` with
    nearest-exact interpolation, the logits are turned into a binary mask
    (sigmoid, threshold 0.5), and for every index along the first axis of the
    resized image a figure is written to
    ``<work_dir>/fig_examples/<category>/<category>_<index>.png``.

    Args:
        category: Name used for the output sub-directory and file-name prefix.
        image: Volume indexed as ``image[0, 0]`` before resizing
            (presumably laid out as (batch, channel, D, H, W) -- TODO confirm).
        bboxes: ``None`` or a sequence whose first element is a tensor of six
            values ``(x1, y1, z1, x2, y2, z2)``. The box is drawn on the first
            panel for slices ``x1 <= j <= x2`` using ``(z1, y1, z2, y2)``.
        points: ``None`` or a pair ``(coords, labels)`` of tensors. A point is
            drawn on the first panel of the slice whose index equals the
            point's first coordinate.
        logits: Raw model output for the volume
            (assumed to have the same 3-D shape as ``gt3D`` -- TODO confirm).
        gt3D: Ground-truth mask volume.
        spatial_size: Target size handed to ``Resized``; the default keeps the
            previously hard-coded ``(32, 256, 256)``.
        work_dir: Directory under which ``fig_examples/<category>/`` is
            created; the default ``'.'`` keeps the previous output location.
    """
    # NOTE(review): AddChanneld is deprecated in newer MONAI releases --
    # confirm against the MONAI version this project pins before upgrading.
    zoom_out_transform = transforms.Compose([
        transforms.AddChanneld(keys=["image", "label", "logits"]),
        transforms.Resized(keys=["image", "label", "logits"],
                           spatial_size=spatial_size, mode='nearest-exact'),
    ])
    print(image.shape, gt3D.shape, logits.shape)
    image = image[0, 0]
    post_item = zoom_out_transform({
        'image': image,
        'label': gt3D,
        'logits': logits,
    })
    # Drop the channel axis that AddChanneld inserted for the resize.
    image, gt3D, logits = post_item['image'][0], post_item['label'][0], post_item['logits'][0]
    preds = (torch.sigmoid(logits) > 0.5).int()

    root_dir = os.path.join(work_dir, f'fig_examples/{category}/')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(root_dir, exist_ok=True)

    if bboxes is not None:
        x1, y1, z1, x2, y2, z2 = bboxes[0].cpu().numpy()
    if points is not None:
        points_ax = points[0].cpu().numpy()     # [n, 3]
        points_label = points[1].cpu().numpy()  # [n]

    for j in range(image.shape[0]):
        img_2d = image[j, :, :].detach().cpu().numpy()
        preds_2d = preds[j, :, :].detach().cpu().numpy()
        label_2d = gt3D[j, :, :].detach().cpu().numpy()

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        try:
            # Original image; prompts are drawn on this panel below.
            ax1.imshow(img_2d, cmap='gray')
            ax1.set_title('Image with prompt')
            ax1.axis('off')

            # Ground truth overlay.
            ax2.imshow(img_2d, cmap='gray')
            show_mask(label_2d, ax2)
            ax2.set_title('Ground truth')
            ax2.axis('off')

            # Prediction overlay.
            ax3.imshow(img_2d, cmap='gray')
            show_mask(preds_2d, ax3)
            ax3.set_title('Prediction')
            ax3.axis('off')

            # Box prompt: only on slices inside the box's first-axis extent.
            if bboxes is not None and x1 <= j <= x2:
                show_box((z1, y1, z2, y2), ax1)

            # Point prompts: only on the slice matching the first coordinate.
            if points is not None:
                for point, label in zip(points_ax, points_label):
                    if j == point[0]:
                        show_points(point, label, ax1)

            fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
            fig.savefig(os.path.join(root_dir, f'{category}_{j}.png'), bbox_inches='tight')
        finally:
            # Always release this figure, even if drawing or saving failed.
            plt.close(fig)
def show_mask(mask, ax):
    """Overlay ``mask`` on ``ax`` as a yellow RGBA image.

    The last two dimensions of ``mask`` give the image height and width. Each
    mask value scales a fixed RGBA colour (alpha component 0.6), so zero-valued
    pixels become all-zero RGBA. The result is drawn with ``alpha=0.35``.
    """
    overlay_rgba = np.array([251/255, 252/255, 30/255, 0.6])
    rows, cols = mask.shape[-2:]
    # (rows, cols, 1) broadcast against (4,) -> (rows, cols, 4)
    tinted = mask.reshape(rows, cols, 1) * overlay_rgba
    ax.imshow(tinted, alpha=0.35)
def show_box(box, ax):
    """Draw an unfilled blue rectangle on ``ax``.

    ``box`` is indexed as ``(x0, y0, x1, y1)``: the first two entries are the
    rectangle's anchor corner, the last two the opposite corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
def show_points(points_ax, points_label, ax):
    """Plot a single prompt point on ``ax``.

    The marker is placed at ``(points_ax[2], points_ax[1])`` and is red when
    ``points_label`` equals 0, blue otherwise.
    """
    print('draw point')
    marker_color = 'blue'
    if points_label == 0:
        marker_color = 'red'
    horizontal, vertical = points_ax[2], points_ax[1]
    ax.scatter(horizontal, vertical, c=marker_color, marker='o', s=200)