Spaces:
Build error
Build error
| """ | |
| Mask R-CNN | |
| Display and Visualization Functions. | |
| Copyright (c) 2017 Matterport, Inc. | |
| Licensed under the MIT License (see LICENSE for details) | |
| Written by Waleed Abdulla | |
| """ | |
| import colorsys | |
| import itertools | |
| import os | |
| import random | |
| import sys | |
| import IPython.display | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib import lines | |
| from matplotlib import patches | |
| from matplotlib.patches import Polygon | |
| from skimage.measure import find_contours | |
| # Root directory of the project | |
| ROOT_DIR = os.path.abspath("../") | |
| # Import Mask RCNN | |
| sys.path.append(ROOT_DIR) # To find local version of the library | |
| from mrcnn import utils | |
| ############################################################ | |
| # Visualization | |
| ############################################################ | |
def display_images(
    images, titles=None, cols=4, cmap=None, norm=None, interpolation=None
):
    """Display the given set of images in a grid, optionally with titles.

    images: list or array of image tensors in HWC format.
    titles: optional. A list of titles to display with each image.
    cols: number of images per row.
    cmap: Optional. Color map to use. For example, "Blues".
    norm: Optional. A Normalize instance to map values to colors.
    interpolation: Optional. Image interpolation to use for display.

    Each image is cast to uint8 before being shown. Calls plt.show() at the end.
    """
    titles = titles if titles is not None else [""] * len(images)
    # Ceiling division: allocate only the rows the images need. The former
    # `len(images) // cols + 1` added an empty trailing row whenever
    # len(images) was an exact multiple of cols. Keep at least one row so the
    # figure height never collapses to zero for an empty input.
    rows = max(1, -(-len(images) // cols))
    # True division keeps each cell's aspect instead of flooring the height.
    plt.figure(figsize=(14, 14 * rows / cols))
    for position, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(rows, cols, position)
        plt.title(title, fontsize=9)
        plt.axis("off")
        plt.imshow(
            image.astype(np.uint8), cmap=cmap, norm=norm, interpolation=interpolation
        )
    plt.show()
def random_colors(N, bright=True):
    """Return ``N`` visually distinct RGB colors in random order.

    Hues are spaced evenly around the HSV wheel (hue = k / N) at full
    saturation, converted to RGB float tuples with colorsys, and the
    resulting list is shuffled in place with random.shuffle.

    N: number of colors to produce.
    bright: HSV value of 1.0 when True, 0.7 otherwise.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette
def apply_mask(image, mask, color, alpha=0.5):
    """Alpha-blend ``color`` into ``image`` wherever ``mask == 1``.

    Only the first three channels are touched. ``color`` components are
    scaled by 255, so they are expected in [0, 1]. The image is modified in
    place and also returned. Outside the mask, pixels are left unchanged.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, blended, plane)
    return image
def display_instances(
    image,
    boxes,
    masks,
    class_ids,
    class_names,
    scores=None,
    title="",
    figsize=(16, 16),
    ax=None,
    show_mask=True,
    show_bbox=True,
    colors=None,
    captions=None,
):
    """Draw instance boxes, captions and mask outlines; return the masked image.

    image: [height, width, 3+] image array.
    boxes: [num_instances, (y1, x1, y2, x2)] in image coordinates.
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) Figure title
    figsize: (optional) size of the figure created when ``ax`` is None
    ax: (optional) Matplotlib axis to draw on. When None, a new figure is
        created and plt.show() is called at the end.
    show_mask, show_bbox: To show masks and bounding boxes or not
    colors: (optional) A list or array of colors, one per instance. Random
        colors are generated when it is None or empty.
    captions: (optional) A list of strings to use as captions for each object

    Instances whose box is all zeros are skipped entirely.

    Returns:
        A uint8 copy of ``image`` with the masks of the drawn instances
        alpha-blended in (only when ``show_mask`` is True).

    NOTE: the blended image is returned but is NOT drawn onto ``ax``; only
    the boxes, captions and mask outlines are added to the axis.
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    # If no axis is passed, create one and automatically call show()
    auto_show = ax is None
    if auto_show:
        _, ax = plt.subplots(1, figsize=figsize)

    # `colors or ...` raises ValueError for multi-element numpy arrays, so
    # test for "not provided" explicitly.
    if colors is None or len(colors) == 0:
        colors = random_colors(N)
    use_class_labels = captions is None or len(captions) == 0

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis("off")
    ax.set_title(title)

    # astype() already returns a new array, so no extra copy is needed.
    masked_image = image.astype(np.uint32)
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=2,
                alpha=0.7,
                linestyle="dashed",
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(p)

        # Label: "<class> <score>" unless explicit captions were given.
        if use_class_labels:
            class_id = class_ids[i]
            score = scores[i] if scores is not None else None
            label = class_names[class_id]
            caption = "{} {:.3f}".format(label, score) if score else label
        else:
            caption = captions[i]
        ax.text(x1, y1 + 8, caption, color="w", size=11, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        if show_mask:
            masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)

    if auto_show:
        plt.show()
    return masked_image.astype(np.uint8)
def display_differences(
    image,
    gt_box,
    gt_class_id,
    gt_mask,
    pred_box,
    pred_class_id,
    pred_score,
    pred_mask,
    class_names,
    title="",
    ax=None,
    show_mask=True,
    show_box=True,
    iou_threshold=0.5,
    score_threshold=0.5,
):
    """Display ground truth and prediction instances on the same image.

    Ground-truth instances are drawn in green with empty captions.
    Predictions are drawn in red and captioned "score / IoU". The IoU is
    taken against the matched ground-truth box, or is the best overlap with
    any ground-truth box when the prediction is unmatched. It is 0.00 when
    there are no ground-truth boxes at all.

    gt_box, gt_class_id, gt_mask: ground-truth boxes, class IDs and masks.
    pred_box, pred_class_id, pred_score, pred_mask: predicted boxes, class
        IDs, scores and masks.
    class_names: list of class names of the dataset.
    title: (optional) figure title; a default legend-style title is used
        when empty.
    ax: (optional) Matplotlib axis to draw on.
    show_mask, show_box: whether to draw masks / boxes.
    iou_threshold, score_threshold: forwarded to utils.compute_matches.

    Returns:
        Whatever display_instances returns (the masked image).
    """
    # Match predictions to ground truth
    gt_match, pred_match, overlaps = utils.compute_matches(
        gt_box,
        gt_class_id,
        gt_mask,
        pred_box,
        pred_class_id,
        pred_score,
        pred_mask,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
    )
    # Ground truth = green. Predictions = red
    colors = [(0, 1, 0, 0.8)] * len(gt_match) + [(1, 0, 0, 1)] * len(pred_match)
    # Concatenate GT and predictions
    class_ids = np.concatenate([gt_class_id, pred_class_id])
    scores = np.concatenate([np.zeros([len(gt_match)]), pred_score])
    boxes = np.concatenate([gt_box, pred_box])
    masks = np.concatenate([gt_mask, pred_mask], axis=-1)

    # Captions per instance show score/IoU
    pred_captions = []
    for i in range(len(pred_match)):
        if pred_match[i] > -1:
            iou = overlaps[i, int(pred_match[i])]
        elif overlaps[i].size:
            iou = overlaps[i].max()
        else:
            # No ground-truth boxes: max() of an empty row would raise.
            iou = 0.0
        pred_captions.append("{:.2f} / {:.2f}".format(pred_score[i], iou))
    captions = [""] * len(gt_match) + pred_captions

    # Set title if not provided
    title = (
        title or "Ground Truth and Detections\n GT=green, pred=red, captions: score/IoU"
    )
    # Display
    return display_instances(
        image,
        boxes,
        masks,
        class_ids,
        class_names,
        scores,
        ax=ax,
        show_bbox=show_box,
        show_mask=show_mask,
        colors=colors,
        captions=captions,
        title=title,
    )
def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
    """Visualize ROIs, their refined versions and their masks in one figure.

    image: image array to draw on; it is copied, not modified.
    rois: [n, (y1, x1, y2, x2)] ROIs in image coordinates.
    refined_rois: [n, (y1, x1, y2, x2)] the same ROIs after refinement; only
        read for ROIs with a non-zero class ID.
    mask: per-ROI masks; mask[j] is passed to utils.unmold_mask with rois[j].
    class_ids: [n] int class IDs. ROIs with class 0 are counted as negative
        and drawn in gray without a refined box or label.
    class_names: list of class names, indexed by class ID.
    limit: maximum number of ROIs to draw; a random subset is shown when
        there are more.

    Also prints the number of positive/negative ROIs and the positive ratio.
    """
    masked_image = image.copy()
    num_rois = rois.shape[0]

    # Pick random ROIs in case there are too many.
    ids = np.arange(num_rois, dtype=np.int32)
    if num_rois > limit:
        ids = np.random.choice(ids, limit, replace=False)

    _, ax = plt.subplots(1, figsize=(12, 12))
    if num_rois > limit:
        plt.title("Showing {} random ROIs out of {}".format(len(ids), num_rois))
    else:
        plt.title("{} ROIs".format(len(ids)))

    # Show area outside image boundaries.
    ax.set_ylim(image.shape[0] + 20, -20)
    ax.set_xlim(-50, image.shape[1] + 20)
    ax.axis("off")

    for roi_id in ids:
        color = np.random.rand(3)
        class_id = class_ids[roi_id]
        # ROI (dashed; gray for negative ROIs)
        y1, x1, y2, x2 = rois[roi_id]
        p = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor=color if class_id else "gray",
            facecolor="none",
            linestyle="dashed",
        )
        ax.add_patch(p)
        # Refined ROI, connecting line and label: positive ROIs only.
        if class_id:
            ry1, rx1, ry2, rx2 = refined_rois[roi_id]
            p = patches.Rectangle(
                (rx1, ry1),
                rx2 - rx1,
                ry2 - ry1,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(p)
            # Connect the top-left corners of the anchor and proposal for easy visualization
            ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))

            # Label
            label = class_names[class_id]
            ax.text(
                rx1,
                ry1 + 8,
                "{}".format(label),
                color="w",
                size=11,
                backgroundcolor="none",
            )

        # Mask
        m = utils.unmold_mask(mask[roi_id], rois[roi_id][:4].astype(np.int32), image.shape)
        masked_image = apply_mask(masked_image, m, color)

    ax.imshow(masked_image)

    # Print stats
    num_positive = class_ids[class_ids > 0].shape[0]
    num_negative = class_ids[class_ids == 0].shape[0]
    total = class_ids.shape[0]
    print("Positive ROIs: ", num_positive)
    print("Negative ROIs: ", num_negative)
    # Guard against an empty class_ids array (0 / 0 would raise).
    ratio = num_positive / total if total else 0.0
    print("Positive Ratio: {:.2f}".format(ratio))
# TODO: Replace with matplotlib equivalent?
def draw_box(image, box, color):
    """Draw a 2-pixel-wide rectangle outline directly into an image array.

    image: array indexed as image[y, x, ...]; modified in place and returned.
    box: (y1, x1, y2, x2) integer pixel coordinates.
    color: value assigned to outline pixels, e.g. a list of 3 ints for RGB.
    """
    y1, x1, y2, x2 = box
    image[y1 : y1 + 2, x1:x2] = color
    # Extend the bottom edge to x2 + 2 so it covers the bottom-right corner;
    # no other edge reaches it, which previously left a 2x2 gap there.
    image[y2 : y2 + 2, x1 : x2 + 2] = color
    image[y1:y2, x1 : x1 + 2] = color
    image[y1:y2, x2 : x2 + 2] = color
    return image
def display_top_masks(image, mask, class_ids, class_names, limit=4):
    """Show the image next to per-class mask panels for its top classes.

    Classes in ``class_ids`` are ranked by total mask area (the sum of mask
    values over all their instances), largest first. Only classes with a
    non-zero area qualify. Up to ``limit`` panels are produced; unused slots
    become an all-zero panel titled "-". In each panel, instance k of the
    class is weighted by k before the instances are summed. The panels are
    rendered with display_images using one row of ``limit + 1`` columns.
    """
    panels = [image]
    panel_titles = ["H x W={}x{}".format(image.shape[0], image.shape[1])]

    # Rank the classes present in this image by mask area.
    present = np.unique(class_ids)
    areas = [np.sum(mask[:, :, np.where(class_ids == cid)[0]]) for cid in present]
    ranked = sorted(zip(present, areas), key=lambda pair: pair[1], reverse=True)
    top_ids = [cid for cid, area in ranked if area > 0]

    # Build one panel per slot.
    for slot in range(limit):
        class_id = top_ids[slot] if slot < len(top_ids) else -1
        # Pull masks of instances belonging to the same class.
        instances = mask[:, :, np.where(class_ids == class_id)[0]]
        weights = np.arange(1, instances.shape[-1] + 1)
        panels.append(np.sum(instances * weights, -1))
        panel_titles.append("-" if class_id == -1 else class_names[class_id])

    display_images(panels, titles=panel_titles, cols=limit + 1, cmap="Blues_r")
def plot_precision_recall(AP, precisions, recalls):
    """Plot precision (y) against recall (x) on a new figure.

    AP: average precision, shown in the title as AP@50 (IoU >= 0.5).
    precisions: sequence of precision values.
    recalls: sequence of recall values, same length as precisions.
    """
    _, axis = plt.subplots(1)
    axis.set_title("Precision-Recall Curve. AP@50 = {:.3f}".format(AP))
    # Both axes span [0, 1.1], leaving a little headroom above 1.0.
    for set_limits in (axis.set_ylim, axis.set_xlim):
        set_limits(0, 1.1)
    axis.plot(recalls, precisions)
def plot_overlaps(
    gt_class_ids, pred_class_ids, pred_scores, overlaps, class_names, threshold=0.5
):
    """Draw a grid showing how ground truth objects are classified.

    gt_class_ids: [N] int. Ground truth class IDs
    pred_class_ids: [N] int. Predicted class IDs
    pred_scores: [N] float. The probability scores of predicted classes
    overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
    class_names: list of all class names in the dataset
    threshold: Float. IoU above which a cell is labelled "match" (same
        class) or "wrong" (different class).

    Entries with class ID 0 are dropped from both ID arrays before labelling.
    NOTE(review): pred_scores is not filtered the same way, so the labels stay
    aligned only if the zero IDs are trailing padding. Confirm with callers.
    """
    gt_class_ids = gt_class_ids[gt_class_ids != 0]
    pred_class_ids = pred_class_ids[pred_class_ids != 0]

    plt.figure(figsize=(12, 10))
    plt.imshow(overlaps, interpolation="nearest", cmap=plt.cm.Blues)
    plt.yticks(
        np.arange(len(pred_class_ids)),
        [
            "{} ({:.2f})".format(class_names[int(class_id)], pred_scores[i])
            for i, class_id in enumerate(pred_class_ids)
        ],
    )
    plt.xticks(
        np.arange(len(gt_class_ids)),
        [class_names[int(class_id)] for class_id in gt_class_ids],
        rotation=90,
    )

    # Cells in the upper half of the IoU range get white text for contrast.
    # .max() of a zero-size array raises, so guard the empty case.
    text_thresh = overlaps.max() / 2.0 if overlaps.size else 0.0
    for i, j in itertools.product(range(overlaps.shape[0]), range(overlaps.shape[1])):
        iou = overlaps[i, j]
        text = ""
        if iou > threshold:
            text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
        if iou > text_thresh:
            color = "white"
        elif iou > 0:
            color = "black"
        else:
            color = "grey"
        plt.text(
            j,
            i,
            "{:.3f}\n{}".format(iou, text),
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=9,
            color=color,
        )

    plt.xlabel("Ground Truth")
    plt.ylabel("Predictions")
    # tight_layout must run after the labels exist, otherwise they are not
    # accounted for and can be clipped.
    plt.tight_layout()
def draw_boxes(
    image,
    boxes=None,
    refined_boxes=None,
    masks=None,
    captions=None,
    visibilities=None,
    title="",
    ax=None,
):
    """Draw bounding boxes and segmentation masks with different
    customizations.

    boxes: [N, (y1, x1, y2, x2)] in image coordinates.
    refined_boxes: Like boxes, but drawn with solid lines to show
        that they're the result of refining 'boxes'.
    masks: [height, width, N]
    captions: List of N titles to display on each box
    visibilities: (optional) List of values of 0, 1, or 2. Determine how
        prominent each bounding box should be:
        0 = gray, dotted, alpha 0.5, refined box not drawn;
        1 = random color, dotted;
        2 = random color, solid.
    title: An optional title to show over the image
    ax: (optional) Matplotlib axis to draw on.

    At least one of ``boxes`` / ``refined_boxes`` must be given. Captions are
    anchored on the refined box when refined boxes are given.

    Raises:
        ValueError: if a visibility value is not 0, 1 or 2.
    """
    # Local import: matplotlib is already a dependency of this module; only
    # needed to turn named colors into RGB for mask blending.
    from matplotlib.colors import to_rgb

    # Number of boxes
    assert boxes is not None or refined_boxes is not None
    N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]

    # Matplotlib Axis
    if ax is None:
        _, ax = plt.subplots(1, figsize=(12, 12))

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    margin = image.shape[0] // 10
    ax.set_ylim(image.shape[0] + margin, -margin)
    ax.set_xlim(-margin, image.shape[1] + margin)
    ax.axis("off")
    ax.set_title(title)

    # astype() already returns a new array, so no extra copy is needed.
    masked_image = image.astype(np.uint32)
    for i in range(N):
        # Box visibility -> (edge color, line style, alpha)
        visibility = visibilities[i] if visibilities is not None else 1
        if visibility == 0:
            color, style, alpha = "gray", "dotted", 0.5
        elif visibility == 1:
            color, style, alpha = colors[i], "dotted", 1
        elif visibility == 2:
            color, style, alpha = colors[i], "solid", 1
        else:
            # Previously an unknown value left color/style/alpha unbound, or
            # silently reused the previous iteration's values.
            raise ValueError(
                "visibilities[{}] must be 0, 1 or 2, got {!r}".format(i, visibility)
            )

        # Boxes
        if boxes is not None:
            if not np.any(boxes[i]):
                # Skip this instance. Has no bbox. Likely lost in cropping.
                continue
            y1, x1, y2, x2 = boxes[i]
            p = patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=2,
                alpha=alpha,
                linestyle=style,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(p)

        # Refined boxes. Coordinates are read for every instance, because the
        # caption below is anchored on them even when the box is not drawn.
        if refined_boxes is not None:
            ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32)
            if visibility > 0:
                p = patches.Rectangle(
                    (rx1, ry1),
                    rx2 - rx1,
                    ry2 - ry1,
                    linewidth=2,
                    edgecolor=color,
                    facecolor="none",
                )
                ax.add_patch(p)
                # Connect the top-left corners of the anchor and proposal
                if boxes is not None:
                    ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))

        # Captions
        if captions is not None:
            caption = captions[i]
            # If there are refined boxes, display captions on them
            if refined_boxes is not None:
                y1, x1, y2, x2 = ry1, rx1, ry2, rx2
            ax.text(
                x1,
                y1,
                caption,
                size=11,
                verticalalignment="top",
                color="w",
                backgroundcolor="none",
                bbox={"facecolor": color, "alpha": 0.5, "pad": 2, "edgecolor": "none"},
            )

        # Masks
        if masks is not None:
            mask = masks[:, :, i]
            # apply_mask indexes color[0..2], so a named color such as "gray"
            # must be converted to an RGB tuple first.
            mask_color = to_rgb(color) if isinstance(color, str) else color
            masked_image = apply_mask(masked_image, mask, mask_color)
            # Mask Polygon
            # Pad to ensure proper polygons for masks that touch image edges.
            padded_mask = np.zeros(
                (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8
            )
            padded_mask[1:-1, 1:-1] = mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
def display_table(table):
    """Render rows of values as an HTML table via IPython.display.

    table: an iterable of rows, and each row is an iterable of values.
    Each value is converted with str() and left-padded into a cell at least
    40 characters wide. Values are inserted without HTML escaping, so callers
    can embed markup (display_weight_stats relies on this for its alerts).
    """
    body = "".join(
        "<tr>" + "".join("<td>{:40}</td>".format(str(cell)) for cell in row) + "</tr>"
        for row in table
    )
    IPython.display.display(IPython.display.HTML("<table>" + body + "</table>"))
def display_weight_stats(model):
    """Display a table of min/max/std statistics for every trainable weight.

    For each layer in ``model.get_trainable_layers()`` and each of its weight
    arrays, one row is added with the weight's name, shape, min, max and
    standard deviation. The name is flagged in red when the weight looks
    dead (all values equal), except for the second weight of a Conv2D layer,
    which is its bias. It is also flagged when it may be overflowing (|min|
    or |max| above 1000). The table is rendered with display_table();
    nothing is returned.

    model: object exposing get_trainable_layers(). Each layer must provide
        get_weights() (numpy arrays) and .weights (objects with a .name).
        Presumably a Keras-based Mask R-CNN model; verify against callers.
    """
    table = [["WEIGHT NAME", "SHAPE", "MIN", "MAX", "STD"]]
    for layer in model.get_trainable_layers():
        weight_values = layer.get_weights()  # list of Numpy arrays
        weight_tensors = layer.weights  # list of TF tensors
        is_conv2d = layer.__class__.__name__ == "Conv2D"
        for i, (w, tensor) in enumerate(zip(weight_values, weight_tensors)):
            w_min, w_max = w.min(), w.max()
            # Detect problematic layers. Exclude biases of conv layers.
            alert = ""
            if w_min == w_max and not (is_conv2d and i == 1):
                alert += "<span style='color:red'>*** dead?</span>"
            if np.abs(w_min) > 1000 or np.abs(w_max) > 1000:
                alert += "<span style='color:red'>*** Overflow?</span>"
            # Add row
            table.append(
                [
                    tensor.name + alert,
                    str(w.shape),
                    "{:+9.4f}".format(w_min),
                    "{:+10.4f}".format(w_max),
                    "{:+9.4f}".format(w.std()),
                ]
            )
    display_table(table)