# Last updated by sergiopaniego (revision ca815e1).
from io import BytesIO
import base64
import numpy as np
import matplotlib.pyplot as plt
import torch
def fig_to_base64(fig):
    """Render a Matplotlib figure to PNG and return it base64-encoded.

    The figure is always closed via ``plt.close``, including when rendering
    fails, so callers never need to clean it up themselves.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        The PNG bytes of the figure (cropped with ``bbox_inches='tight'``),
        base64-encoded and decoded to a ``str``.
    """
    buf = BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Close even if savefig raises, so a failing render does not leave
        # the figure registered with pyplot.
        plt.close(fig)
    # getvalue() returns the whole buffer regardless of stream position,
    # so no seek(0) is needed.
    return base64.b64encode(buf.getvalue()).decode()
def show_mask(mask, ax, random_color=False):
    """Overlay one mask on ``ax`` as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width). It must
            hold exactly height * width elements, because it is reshaped
            to (height, width, 1).
        ax: Axes-like object providing ``imshow``.
        random_color: If true, use a random RGB colour. Otherwise use the
            fixed colour (30, 144, 255)/255. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4): each mask value scales
    # every RGBA channel.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_box(box, ax):
    """Draw a green, unfilled rectangle outline on ``ax``.

    Args:
        box: Indexable ``[x_min, y_min, x_max, y_max]`` corner coordinates.
        ax: Matplotlib axes to draw on.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill: outline only
        lw=2,
    )
    ax.add_patch(outline)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Args:
        coords: Array of shape (N, 2); column 0 is plotted as x and column 1
            as y.
        labels: Array of N labels. Points labelled 1 are drawn green and
            points labelled 0 are drawn red. Other labels are not drawn.
        ax: Axes-like object providing ``scatter``.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    # Select both groups up front, then draw the positives first.
    groups = {
        colour: coords[labels == value]
        for value, colour in ((1, 'green'), (0, 'red'))
    }
    for colour, pts in groups.items():
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_boxes_on_image_base64(raw_image, boxes):
    """Draw every box over ``raw_image`` and return a base64-encoded PNG.

    Args:
        raw_image: Anything ``Axes.imshow`` accepts (presumably a PIL image
            or array; confirm against callers).
        boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` boxes.

    Returns:
        Base64 PNG string produced by ``fig_to_base64``.
    """
    figure, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(raw_image)
    for single_box in boxes:
        show_box(single_box, axis)
    axis.axis('off')
    return fig_to_base64(figure)
def show_points_on_image_base64(raw_image, input_points, input_labels=None):
    """Draw prompt points over ``raw_image`` and return a base64-encoded PNG.

    Args:
        raw_image: Anything ``Axes.imshow`` accepts.
        input_points: Sequence convertible to an (N, 2) array of x, y points.
        input_labels: Optional per-point labels (1 positive, 0 negative).
            When omitted, every point is treated as positive.

    Returns:
        Base64 PNG string produced by ``fig_to_base64``.
    """
    figure, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(raw_image)
    points = np.array(input_points)
    if input_labels is None:
        point_labels = np.ones_like(points[:, 0])
    else:
        point_labels = np.array(input_labels)
    show_points(points, point_labels, axis)
    axis.axis('off')
    return fig_to_base64(figure)
def show_points_and_boxes_on_image_base64(raw_image, boxes, input_points, input_labels=None):
    """Draw prompt points and boxes over ``raw_image``; return base64 PNG.

    Args:
        raw_image: Anything ``Axes.imshow`` accepts.
        boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        input_points: Sequence convertible to an (N, 2) array of x, y points.
        input_labels: Optional per-point labels (1 positive, 0 negative).
            When omitted, every point is treated as positive.

    Returns:
        Base64 PNG string produced by ``fig_to_base64``.
    """
    figure, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(raw_image)

    points = np.array(input_points)
    if input_labels is None:
        point_labels = np.ones_like(points[:, 0])
    else:
        point_labels = np.array(input_labels)
    show_points(points, point_labels, axis)

    for single_box in boxes:
        show_box(single_box, axis)

    axis.axis('off')
    return fig_to_base64(figure)
def show_masks_on_image_base64(raw_image, masks, scores):
    """Render each predicted mask in its own panel and return a base64 PNG.

    Args:
        raw_image: Image convertible with ``np.array`` (presumably a PIL
            image; confirm against callers).
        masks: torch tensor or numpy array of masks. Leading size-1 (batch)
            dimensions are dropped until it is 3-D, i.e. (num_masks, H, W).
        scores: torch tensor or numpy array of per-mask scores. Leading
            size-1 dimensions are dropped until it is 1-D, i.e. (num_masks,).

    Returns:
        Base64 PNG string of a 1 x num_masks grid. Each panel is titled
        "Mask <i>, Score: <score>".
    """
    # Strip only *leading* singleton dims. A blanket squeeze() also removed
    # the mask axis when there was a single prediction: (1, 1, H, W) masks
    # became (H, W), and (1, 1) or (1, 1, 1) scores became 0-d, which made
    # scores.shape[-1] raise.
    while masks.ndim > 3 and masks.shape[0] == 1:
        masks = masks[0]
    while scores.ndim > 1 and scores.shape[0] == 1:
        scores = scores[0]

    nb_predictions = scores.shape[-1]
    # squeeze=False always yields a 2-D axes array, even for one panel.
    fig, axes = plt.subplots(
        1, nb_predictions, figsize=(5 * nb_predictions, 5), squeeze=False
    )
    image = np.array(raw_image)  # convert once, reuse for every panel

    for i, (ax, mask, score) in enumerate(zip(axes[0], masks, scores), start=1):
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().detach().numpy()
        ax.imshow(image)
        show_mask(mask, ax)
        ax.title.set_text(f"Mask {i}, Score: {score.item():.3f}")
        ax.axis("off")
    return fig_to_base64(fig)
def show_first_mask_on_image_base64(raw_image, masks, scores):
    """Overlay only the first mask on ``raw_image``; return a base64 PNG.

    Args:
        raw_image: Image convertible with ``np.array``.
        masks: Array or tensor. A 4-D input uses ``masks[0, 0]``, a 3-D input
            uses ``masks[0]``, and anything else is used as-is.
        scores: Optional array or tensor. Its first element after flattening
            is shown as "Score: x.xxx" in the title. When None, the title is
            empty.

    Returns:
        Base64 PNG string produced by ``fig_to_base64``.
    """
    first_index = {4: (0, 0), 3: (0,)}.get(masks.ndim)
    mask = masks if first_index is None else masks[first_index]
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().detach().numpy()

    title = ""
    if scores is not None:
        if isinstance(scores, torch.Tensor):
            first_score = scores.flatten()[0].item()
        else:
            first_score = float(np.array(scores).flatten()[0])
        title = f"Score: {first_score:.3f}"

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(np.array(raw_image))
    show_mask(mask, ax)
    ax.set_title(title)
    ax.axis("off")
    return fig_to_base64(fig)
def show_all_annotations_on_image_base64(raw_image, masks=None, scores=None, boxes=None, input_points=None, input_labels=None, model_name=None):
    """Draw the first mask, prompt points and boxes on one image.

    Args:
        raw_image: Image convertible with ``np.array``.
        masks: Optional array or tensor. A 4-D input draws ``masks[0, 0]``,
            a 3-D input draws ``masks[0]``, and anything else is drawn as-is.
        scores: Accepted for API compatibility; currently not rendered.
        boxes: Optional iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        input_points: Optional sequence convertible to an (N, 2) x, y array.
        input_labels: Optional per-point labels (1 positive, 0 negative).
            When omitted, every point is treated as positive.
        model_name: Optional title for the figure. No title is drawn when
            it is None.

    Returns:
        Base64 PNG string produced by ``fig_to_base64``.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.array(raw_image))

    if masks is not None:
        if masks.ndim == 4:
            mask = masks[0, 0]
        elif masks.ndim == 3:
            mask = masks[0]
        else:
            mask = masks
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().detach().numpy()
        show_mask(mask, ax)

    # The old code computed a score that was never used and rendered the
    # literal title "None" when no model name was given.
    if model_name is not None:
        ax.set_title(f"{model_name}")

    if input_points is not None:
        input_points = np.array(input_points)
        if input_labels is None:
            labels = np.ones_like(input_points[:, 0])
        else:
            labels = np.array(input_labels)
        show_points(input_points, labels, ax)

    if boxes is not None:
        for box in boxes:
            show_box(box, ax)

    ax.axis("off")
    return fig_to_base64(fig)