Spaces:
Paused
Paused
| """ | |
| coding=utf-8 | |
| Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal | |
| Adapted From Facebook Inc, Detectron2 | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License.import copy | |
| """ | |
| import colorsys | |
| import io | |
| import cv2 | |
| import matplotlib as mpl | |
| import matplotlib.colors as mplc | |
| import matplotlib.figure as mplfigure | |
| import numpy as np | |
| import torch | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| from utils import img_tensorize | |
| _SMALL_OBJ = 1000 | |
| class SingleImageViz: | |
| def __init__( | |
| self, | |
| img, | |
| scale=1.2, | |
| edgecolor="g", | |
| alpha=0.5, | |
| linestyle="-", | |
| saveas="test_out.jpg", | |
| rgb=True, | |
| pynb=False, | |
| id2obj=None, | |
| id2attr=None, | |
| pad=0.7, | |
| ): | |
| """ | |
| img: an RGB image of shape (H, W, 3). | |
| """ | |
| if isinstance(img, torch.Tensor): | |
| img = img.numpy().astype("np.uint8") | |
| if isinstance(img, str): | |
| img = img_tensorize(img) | |
| assert isinstance(img, np.ndarray) | |
| width, height = img.shape[1], img.shape[0] | |
| fig = mplfigure.Figure(frameon=False) | |
| dpi = fig.get_dpi() | |
| width_in = (width * scale + 1e-2) / dpi | |
| height_in = (height * scale + 1e-2) / dpi | |
| fig.set_size_inches(width_in, height_in) | |
| ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) | |
| ax.axis("off") | |
| ax.set_xlim(0.0, width) | |
| ax.set_ylim(height) | |
| self.saveas = saveas | |
| self.rgb = rgb | |
| self.pynb = pynb | |
| self.img = img | |
| self.edgecolor = edgecolor | |
| self.alpha = 0.5 | |
| self.linestyle = linestyle | |
| self.font_size = int(np.sqrt(min(height, width)) * scale // 3) | |
| self.width = width | |
| self.height = height | |
| self.scale = scale | |
| self.fig = fig | |
| self.ax = ax | |
| self.pad = pad | |
| self.id2obj = id2obj | |
| self.id2attr = id2attr | |
| self.canvas = FigureCanvasAgg(fig) | |
| def add_box(self, box, color=None): | |
| if color is None: | |
| color = self.edgecolor | |
| (x0, y0, x1, y1) = box | |
| width = x1 - x0 | |
| height = y1 - y0 | |
| self.ax.add_patch( | |
| mpl.patches.Rectangle( | |
| (x0, y0), | |
| width, | |
| height, | |
| fill=False, | |
| edgecolor=color, | |
| linewidth=self.font_size // 3, | |
| alpha=self.alpha, | |
| linestyle=self.linestyle, | |
| ) | |
| ) | |
| def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None): | |
| if len(boxes.shape) > 2: | |
| boxes = boxes[0] | |
| if len(obj_ids.shape) > 1: | |
| obj_ids = obj_ids[0] | |
| if len(obj_scores.shape) > 1: | |
| obj_scores = obj_scores[0] | |
| if len(attr_ids.shape) > 1: | |
| attr_ids = attr_ids[0] | |
| if len(attr_scores.shape) > 1: | |
| attr_scores = attr_scores[0] | |
| if isinstance(boxes, torch.Tensor): | |
| boxes = boxes.numpy() | |
| if isinstance(boxes, list): | |
| boxes = np.array(boxes) | |
| assert isinstance(boxes, np.ndarray) | |
| areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) | |
| sorted_idxs = np.argsort(-areas).tolist() | |
| boxes = boxes[sorted_idxs] if boxes is not None else None | |
| obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None | |
| obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None | |
| attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None | |
| attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None | |
| assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))] | |
| assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] | |
| if obj_ids is not None: | |
| labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores) | |
| for i in range(len(boxes)): | |
| color = assigned_colors[i] | |
| self.add_box(boxes[i], color) | |
| self.draw_labels(labels[i], boxes[i], color) | |
| def draw_labels(self, label, box, color): | |
| x0, y0, x1, y1 = box | |
| text_pos = (x0, y0) | |
| instance_area = (y1 - y0) * (x1 - x0) | |
| small = _SMALL_OBJ * self.scale | |
| if instance_area < small or y1 - y0 < 40 * self.scale: | |
| if y1 >= self.height - 5: | |
| text_pos = (x1, y0) | |
| else: | |
| text_pos = (x0, y1) | |
| height_ratio = (y1 - y0) / np.sqrt(self.height * self.width) | |
| lighter_color = self._change_color_brightness(color, brightness_factor=0.7) | |
| font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) | |
| font_size *= 0.75 * self.font_size | |
| self.draw_text( | |
| text=label, | |
| position=text_pos, | |
| color=lighter_color, | |
| ) | |
| def draw_text( | |
| self, | |
| text, | |
| position, | |
| color="g", | |
| ha="left", | |
| ): | |
| rotation = 0 | |
| font_size = self.font_size | |
| color = np.maximum(list(mplc.to_rgb(color)), 0.2) | |
| color[np.argmax(color)] = max(0.8, np.max(color)) | |
| bbox = { | |
| "facecolor": "black", | |
| "alpha": self.alpha, | |
| "pad": self.pad, | |
| "edgecolor": "none", | |
| } | |
| x, y = position | |
| self.ax.text( | |
| x, | |
| y, | |
| text, | |
| size=font_size * self.scale, | |
| family="sans-serif", | |
| bbox=bbox, | |
| verticalalignment="top", | |
| horizontalalignment=ha, | |
| color=color, | |
| zorder=10, | |
| rotation=rotation, | |
| ) | |
| def save(self, saveas=None): | |
| if saveas is None: | |
| saveas = self.saveas | |
| if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"): | |
| cv2.imwrite( | |
| saveas, | |
| self._get_buffer()[:, :, ::-1], | |
| ) | |
| else: | |
| self.fig.savefig(saveas) | |
| def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores): | |
| labels = [self.id2obj[i] for i in classes] | |
| attr_labels = [self.id2attr[i] for i in attr_classes] | |
| labels = [ | |
| f"{label} {score:.2f} {attr} {attr_score:.2f}" | |
| for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores) | |
| ] | |
| return labels | |
| def _create_text_labels(self, classes, scores): | |
| labels = [self.id2obj[i] for i in classes] | |
| if scores is not None: | |
| if labels is None: | |
| labels = ["{:.0f}%".format(s * 100) for s in scores] | |
| else: | |
| labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)] | |
| return labels | |
| def _random_color(self, maximum=255): | |
| idx = np.random.randint(0, len(_COLORS)) | |
| ret = _COLORS[idx] * maximum | |
| if not self.rgb: | |
| ret = ret[::-1] | |
| return ret | |
| def _get_buffer(self): | |
| if not self.pynb: | |
| s, (width, height) = self.canvas.print_to_buffer() | |
| if (width, height) != (self.width, self.height): | |
| img = cv2.resize(self.img, (width, height)) | |
| else: | |
| img = self.img | |
| else: | |
| buf = io.BytesIO() # works for cairo backend | |
| self.canvas.print_rgba(buf) | |
| width, height = self.width, self.height | |
| s = buf.getvalue() | |
| img = self.img | |
| buffer = np.frombuffer(s, dtype="uint8") | |
| img_rgba = buffer.reshape(height, width, 4) | |
| rgb, alpha = np.split(img_rgba, [3], axis=2) | |
| try: | |
| import numexpr as ne # fuse them with numexpr | |
| visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)") | |
| except ImportError: | |
| alpha = alpha.astype("float32") / 255.0 | |
| visualized_image = img * (1 - alpha) + rgb * alpha | |
| return visualized_image.astype("uint8") | |
| def _change_color_brightness(self, color, brightness_factor): | |
| assert brightness_factor >= -1.0 and brightness_factor <= 1.0 | |
| color = mplc.to_rgb(color) | |
| polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) | |
| modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) | |
| modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness | |
| modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness | |
| modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) | |
| return modified_color | |
| # Color map | |
| _COLORS = ( | |
| np.array( | |
| [ | |
| 0.000, | |
| 0.447, | |
| 0.741, | |
| 0.850, | |
| 0.325, | |
| 0.098, | |
| 0.929, | |
| 0.694, | |
| 0.125, | |
| 0.494, | |
| 0.184, | |
| 0.556, | |
| 0.466, | |
| 0.674, | |
| 0.188, | |
| 0.301, | |
| 0.745, | |
| 0.933, | |
| 0.635, | |
| 0.078, | |
| 0.184, | |
| 0.300, | |
| 0.300, | |
| 0.300, | |
| 0.600, | |
| 0.600, | |
| 0.600, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.500, | |
| 0.000, | |
| 0.749, | |
| 0.749, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.667, | |
| 0.000, | |
| 1.000, | |
| 0.333, | |
| 0.333, | |
| 0.000, | |
| 0.333, | |
| 0.667, | |
| 0.000, | |
| 0.333, | |
| 1.000, | |
| 0.000, | |
| 0.667, | |
| 0.333, | |
| 0.000, | |
| 0.667, | |
| 0.667, | |
| 0.000, | |
| 0.667, | |
| 1.000, | |
| 0.000, | |
| 1.000, | |
| 0.333, | |
| 0.000, | |
| 1.000, | |
| 0.667, | |
| 0.000, | |
| 1.000, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 0.333, | |
| 0.500, | |
| 0.000, | |
| 0.667, | |
| 0.500, | |
| 0.000, | |
| 1.000, | |
| 0.500, | |
| 0.333, | |
| 0.000, | |
| 0.500, | |
| 0.333, | |
| 0.333, | |
| 0.500, | |
| 0.333, | |
| 0.667, | |
| 0.500, | |
| 0.333, | |
| 1.000, | |
| 0.500, | |
| 0.667, | |
| 0.000, | |
| 0.500, | |
| 0.667, | |
| 0.333, | |
| 0.500, | |
| 0.667, | |
| 0.667, | |
| 0.500, | |
| 0.667, | |
| 1.000, | |
| 0.500, | |
| 1.000, | |
| 0.000, | |
| 0.500, | |
| 1.000, | |
| 0.333, | |
| 0.500, | |
| 1.000, | |
| 0.667, | |
| 0.500, | |
| 1.000, | |
| 1.000, | |
| 0.500, | |
| 0.000, | |
| 0.333, | |
| 1.000, | |
| 0.000, | |
| 0.667, | |
| 1.000, | |
| 0.000, | |
| 1.000, | |
| 1.000, | |
| 0.333, | |
| 0.000, | |
| 1.000, | |
| 0.333, | |
| 0.333, | |
| 1.000, | |
| 0.333, | |
| 0.667, | |
| 1.000, | |
| 0.333, | |
| 1.000, | |
| 1.000, | |
| 0.667, | |
| 0.000, | |
| 1.000, | |
| 0.667, | |
| 0.333, | |
| 1.000, | |
| 0.667, | |
| 0.667, | |
| 1.000, | |
| 0.667, | |
| 1.000, | |
| 1.000, | |
| 1.000, | |
| 0.000, | |
| 1.000, | |
| 1.000, | |
| 0.333, | |
| 1.000, | |
| 1.000, | |
| 0.667, | |
| 1.000, | |
| 0.333, | |
| 0.000, | |
| 0.000, | |
| 0.500, | |
| 0.000, | |
| 0.000, | |
| 0.667, | |
| 0.000, | |
| 0.000, | |
| 0.833, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 0.000, | |
| 0.167, | |
| 0.000, | |
| 0.000, | |
| 0.333, | |
| 0.000, | |
| 0.000, | |
| 0.500, | |
| 0.000, | |
| 0.000, | |
| 0.667, | |
| 0.000, | |
| 0.000, | |
| 0.833, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 0.000, | |
| 0.167, | |
| 0.000, | |
| 0.000, | |
| 0.333, | |
| 0.000, | |
| 0.000, | |
| 0.500, | |
| 0.000, | |
| 0.000, | |
| 0.667, | |
| 0.000, | |
| 0.000, | |
| 0.833, | |
| 0.000, | |
| 0.000, | |
| 1.000, | |
| 0.000, | |
| 0.000, | |
| 0.000, | |
| 0.143, | |
| 0.143, | |
| 0.143, | |
| 0.857, | |
| 0.857, | |
| 0.857, | |
| 1.000, | |
| 1.000, | |
| 1.000, | |
| ] | |
| ) | |
| .astype(np.float32) | |
| .reshape(-1, 3) | |
| ) | |