|
from io import BytesIO |
|
import base64 |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
|
|
|
|
def fig_to_base64(fig):
    """Render a Matplotlib figure to PNG and return it as base64 text.

    The figure is always closed afterwards, even when saving fails, so that
    pyplot does not keep a reference to it.

    Args:
        fig: The figure to render.

    Returns:
        The base64-encoded PNG bytes as a ``str`` (no ``data:`` URI prefix).
    """
    buf = BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # pyplot retains every figure it created until it is closed; release
        # it on the error path too, otherwise repeated failures leak figures.
        plt.close(fig)
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no seek(0) is needed. Base64 output is pure ASCII.
    return base64.b64encode(buf.getvalue()).decode('ascii')
|
|
|
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two axes are (H, W). Any leading axes must have
            size 1, because the mask is reshaped to (H, W, 1).
        ax: Axes-like object exposing ``imshow``.
        random_color: If True, use a random RGB colour. Otherwise use the fixed
            blue (30, 144, 255) / 255. Alpha is 0.6 either way.

    Pixels where ``mask`` is 0/False become (0, 0, 0, 0), i.e. fully transparent.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    rows, cols = mask.shape[-2:]
    # Broadcast the per-pixel mask value against the 4-channel colour.
    overlay = mask.reshape(rows, cols, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
|
|
|
def show_box(box, ax):
    """Outline ``box`` on ``ax`` as an unfilled green rectangle.

    Args:
        box: Indexable as (x0, y0, x1, y1); only the first four items are read.
        ax: Axes-like object exposing ``add_patch``.
    """
    origin_x, origin_y = box[0], box[1]
    span_x = box[2] - origin_x
    span_y = box[3] - origin_y
    outline = plt.Rectangle(
        (origin_x, origin_y),
        span_x,
        span_y,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Args:
        coords: (N, 2)-style array of (x, y) rows supporting boolean indexing.
        labels: Per-point labels aligned with ``coords``. Points labelled 1 are
            drawn green, points labelled 0 red; any other label is not drawn.
        ax: Axes-like object exposing ``scatter``.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = [
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    ]
    for chosen, tint in groups:
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=tint,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
|
|
|
def show_boxes_on_image_base64(raw_image, boxes):
    """Draw ``boxes`` over ``raw_image`` and return the figure as base64 PNG.

    Args:
        raw_image: Image accepted by ``Axes.imshow``.
        boxes: Iterable of (x0, y0, x1, y1) boxes, each drawn by ``show_box``.

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(raw_image)
        for box in boxes:
            show_box(box, ax)
        ax.axis('off')
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails; plt.close on a figure
        # fig_to_base64 has already closed does nothing.
        plt.close(fig)
|
|
|
def show_points_on_image_base64(raw_image, input_points, input_labels=None):
    """Draw prompt points over ``raw_image`` and return base64 PNG.

    Args:
        raw_image: Image accepted by ``Axes.imshow``.
        input_points: Point coordinates. Accepted forms are an (N, 2) array,
            a single [x, y] pair, an empty list, or a nested batch layout such
            as [[[x, y], ...]], which is flattened to one row per point.
        input_labels: Optional per-point labels (1 = positive, 0 = negative),
            flattened to one label per point. Defaults to all positive.

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.
    """
    points = np.array(input_points)
    if points.ndim < 2:
        # A single [x, y] pair, or an empty list -> (0, 2).
        points = points.reshape(-1, 2)
    else:
        # Collapse any leading batch axes; a plain (N, 2) array is unchanged.
        points = points.reshape(-1, points.shape[-1])
    if input_labels is None:
        labels = np.ones_like(points[:, 0])
    else:
        labels = np.array(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(raw_image)
        show_points(points, labels, ax)
        ax.axis('off')
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails (no-op if already closed).
        plt.close(fig)
|
|
|
def show_points_and_boxes_on_image_base64(raw_image, boxes, input_points, input_labels=None):
    """Draw prompt points and boxes over ``raw_image`` and return base64 PNG.

    Args:
        raw_image: Image accepted by ``Axes.imshow``.
        boxes: Iterable of (x0, y0, x1, y1) boxes, drawn after the points.
        input_points: Point coordinates. Accepted forms are an (N, 2) array,
            a single [x, y] pair, an empty list, or a nested batch layout such
            as [[[x, y], ...]], which is flattened to one row per point.
        input_labels: Optional per-point labels (1 = positive, 0 = negative),
            flattened to one label per point. Defaults to all positive.

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.
    """
    points = np.array(input_points)
    if points.ndim < 2:
        # A single [x, y] pair, or an empty list -> (0, 2).
        points = points.reshape(-1, 2)
    else:
        # Collapse any leading batch axes; a plain (N, 2) array is unchanged.
        points = points.reshape(-1, points.shape[-1])
    if input_labels is None:
        labels = np.ones_like(points[:, 0])
    else:
        labels = np.array(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(raw_image)
        show_points(points, labels, ax)
        for box in boxes:
            show_box(box, ax)
        ax.axis('off')
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails (no-op if already closed).
        plt.close(fig)
|
|
|
def show_masks_on_image_base64(raw_image, masks, scores):
    """Render every predicted mask side by side over ``raw_image``.

    One subplot is drawn per prediction, titled "Mask <i>, Score: <s>".

    Args:
        raw_image: Image accepted by ``np.array`` / ``Axes.imshow``.
        masks: Torch tensor or array whose last two axes are (H, W). All
            leading axes are flattened into a single prediction axis, so
            (1, 3, H, W) yields 3 predictions and (1, 1, H, W) yields 1.
        scores: Torch tensor or array-like with one score per mask, flattened
            in the same order as ``masks``.

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.

    Raises:
        ValueError: If there are no masks, or the mask and score counts differ.
    """
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    if isinstance(scores, torch.Tensor):
        scores = scores.detach().float().cpu().numpy()
    masks = np.asarray(masks)
    # Flatten leading axes instead of squeeze(): squeeze() turned a single
    # (1, 1, H, W) mask into one (H, W) image (and its score into a 0-d array,
    # breaking shape[-1]), and would also drop a size-1 H or W axis.
    height, width = masks.shape[-2:]
    masks = masks.reshape(-1, height, width)
    scores = np.asarray(scores, dtype=float).reshape(-1)

    nb_predictions = masks.shape[0]
    if nb_predictions == 0:
        raise ValueError("no masks to display")
    if scores.shape[0] != nb_predictions:
        raise ValueError(
            f"expected one score per mask, got {nb_predictions} masks "
            f"and {scores.shape[0]} scores"
        )

    image = np.array(raw_image)
    # squeeze=False always returns a 2-D array of Axes, even for one column.
    fig, axes = plt.subplots(
        1, nb_predictions, figsize=(5 * nb_predictions, 5), squeeze=False
    )
    try:
        for i, (ax, mask, score) in enumerate(zip(axes[0], masks, scores)):
            ax.imshow(image)
            show_mask(mask, ax)
            ax.title.set_text(f"Mask {i+1}, Score: {score:.3f}")
            ax.axis("off")
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails (no-op if already closed).
        plt.close(fig)
|
|
|
def show_first_mask_on_image_base64(raw_image, masks, scores):
    """Render only the first mask over ``raw_image``, titled with its score.

    Args:
        raw_image: Image accepted by ``np.array`` / ``Axes.imshow``.
        masks: Torch tensor or array. Leading axes are indexed at 0 until a 2-D
            (H, W) mask remains; e.g. 4-D input uses ``masks[0, 0]`` and 3-D
            input uses ``masks[0]``.
        scores: Torch tensor, array-like, or None. The first element after
            flattening is shown as "Score: x.xxx". None or an empty container
            produces an empty title.

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.
    """
    mask = masks
    # Peel off leading axes one at a time, for any rank above 2.
    while mask.ndim > 2:
        mask = mask[0]
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().detach().numpy()

    score_text = ""
    if scores is not None:
        if isinstance(scores, torch.Tensor):
            flat = scores.flatten()
            score = flat[0].item() if flat.numel() else None
        else:
            flat = np.array(scores).flatten()
            score = float(flat[0]) if flat.size else None
        if score is not None:
            score_text = f"Score: {score:.3f}"

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(np.array(raw_image))
        show_mask(mask, ax)
        ax.set_title(score_text)
        ax.axis("off")
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails (no-op if already closed).
        plt.close(fig)
|
|
|
def show_all_annotations_on_image_base64(raw_image, masks=None, scores=None, boxes=None,
                                         input_points=None, input_labels=None, model_name=None):
    """Overlay the first mask, prompt points and boxes on ``raw_image``.

    Args:
        raw_image: Image accepted by ``np.array`` / ``Axes.imshow``.
        masks: Optional torch tensor or array. Leading axes are indexed at 0
            until a 2-D (H, W) mask remains; that mask is drawn.
        scores: Accepted for interface compatibility; not rendered.
        boxes: Optional iterable of (x0, y0, x1, y1) boxes.
        input_points: Optional point coordinates. Accepted forms are (N, 2),
            a single [x, y], an empty list, or a nested batch layout, which is
            flattened to one row per point.
        input_labels: Optional per-point labels (1 = positive, 0 = negative).
            Defaults to all positive.
        model_name: Optional title. No title is set when None (previously the
            literal text "None" was shown).

    Returns:
        Base64-encoded PNG string produced by ``fig_to_base64``.
    """
    points = labels = None
    if input_points is not None:
        points = np.array(input_points)
        if points.ndim < 2:
            # A single [x, y] pair, or an empty list -> (0, 2).
            points = points.reshape(-1, 2)
        else:
            # Collapse any leading batch axes; (N, 2) is unchanged.
            points = points.reshape(-1, points.shape[-1])
        if input_labels is None:
            labels = np.ones_like(points[:, 0])
        else:
            labels = np.array(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(np.array(raw_image))

        if masks is not None:
            mask = masks
            # Peel off leading axes one at a time, for any rank above 2.
            while mask.ndim > 2:
                mask = mask[0]
            if isinstance(mask, torch.Tensor):
                mask = mask.cpu().detach().numpy()
            show_mask(mask, ax)

        if model_name is not None:
            ax.set_title(f"{model_name}")

        if points is not None:
            show_points(points, labels, ax)

        if boxes is not None:
            for box in boxes:
                show_box(box, ax)

        ax.axis("off")
        return fig_to_base64(fig)
    finally:
        # Release the figure even if drawing fails (no-op if already closed).
        plt.close(fig)
|
|