"""Visualization and distributed-training utilities for OneFormer-style models.

Contains: mask-overlay visualization helpers, panoptic-prediction unpacking,
a differentiable all-gather for contrastive training, SILog depth loss, and
image-grid composition for qualitative comparisons.
"""
from typing import List, Optional
import io

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import diffdist.functional as diff_dist
from PIL import Image, ImageDraw
from torchvision.ops import masks_to_boxes


def visualize_oneformer_masks_on_image(
    image: torch.Tensor,
    masks: List[torch.Tensor],
    classes: List[str],
    save_path: Optional[str] = None,
):
    """Overlay binary segmentation masks (with class labels) on an image.

    inputs:
        image: image to draw on. NOTE(review): the code feeds it directly to
            ``plt.imshow(np.array(image))`` without transposing, so it must
            already be in HWC layout (e.g. a PIL image or HWC array); a
            (3, H, W) CHW tensor would NOT render correctly — confirm callers.
        masks: List[torch.Tensor] of len NUM_MASKS; each an (H, W) binary mask
        classes: List[str] of len NUM_MASKS, one label per mask
        save_path: Optional[str] path to also save the visualization to disk
    returns:
        pil_image: PIL.Image with the masks overlayed on the image
    """

    def _show_mask(mask, class_name, ax, random_color=False):
        # Paint one translucent mask on `ax` and write its class name at the
        # center of the mask's bounding box.
        mask = mask.cpu()
        box = masks_to_boxes(mask.unsqueeze(0))[0]
        x0, y0, x1, y1 = box
        x = (x0 + x1) / 2
        y = (y0 + y1) / 2
        if random_color:
            # Random RGB with fixed 0.6 alpha so overlapping masks stay visible.
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        # (H, W) binary mask -> (H, W, 4) RGBA image via broadcasting.
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
        ax.text(x, y, class_name, fontsize="x-small")

    fig, ax = plt.subplots()
    ax.imshow(np.array(image))
    ax.set_autoscale_on(False)
    for mask, class_name in zip(masks, classes):
        _show_mask(mask, class_name, ax=ax, random_color=True)
    plt.axis("off")
    plt.tight_layout()

    # Render the figure into an in-memory PNG and wrap it as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    pil_image = Image.open(buf)

    if save_path is not None:
        pil_image.save(save_path)
    plt.close(fig)
    return pil_image


def oneformer_prepare_panoptic_instance_prediction(
    segmentation: torch.Tensor, segments_info: dict, oneformer
):
    """Split a panoptic segmentation map into per-segment masks and labels.

    inputs:
        segmentation: (H, W) tensor of integer segment ids
        segments_info: iterable of dicts with "id" and "label_id" keys
        oneformer: object exposing ``config.id2label`` (label_id -> name)
    returns:
        (masks, classes): list of float binary masks and their label strings
    """
    masks = []
    classes = []
    for segment in segments_info:
        # `segment_id` rather than `id` to avoid shadowing the builtin.
        segment_id = segment["id"]
        label = oneformer.config.id2label[segment["label_id"]]
        masks.append((segmentation == segment_id).float())
        classes.append(label)
    return masks, classes


def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def dist_collect(x):
    """
    collect all tensor from all GPUs
    args:
        x: shape (mini_batch, ...)
    returns:
        shape (mini_batch * num_gpu, ...)

    Uses diffdist so gradients flow through the all_gather.
    """
    x = x.contiguous()
    out_list = [
        torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous()
        for _ in range(dist.get_world_size())
    ]
    out_list = diff_dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0).contiguous()


def calculate_contrastive_loss(preds, targets, logit_scale):
    """Per-sample InfoNCE loss between `preds` and `targets` embeddings.

    In distributed mode the targets are gathered from all ranks, and the
    positive index for each local sample is offset by rank * batch_size.

    inputs:
        preds, targets: (B, ...) tensors; flattened and L2-normalized
        logit_scale: learnable scalar; exp() is clamped to 100 (CLIP-style)
    returns:
        (B,) unreduced cross-entropy losses
    """
    batch_size = preds.shape[0]
    if is_dist_avail_and_initialized():
        labels = (
            torch.arange(batch_size, dtype=torch.long, device=preds.device)
            + batch_size * dist.get_rank()
        )
    else:
        labels = torch.arange(batch_size, dtype=torch.long, device=preds.device)

    preds = F.normalize(preds.flatten(1), dim=-1)
    targets = F.normalize(targets.flatten(1), dim=-1)
    if is_dist_avail_and_initialized():
        logits_per_img = preds @ dist_collect(targets).t()
    else:
        logits_per_img = preds @ targets.t()

    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    loss_contrastive = F.cross_entropy(
        logits_per_img * logit_scale, labels, reduction="none"
    )
    return loss_contrastive


def silog_loss(depth_est, depth_gt, variance_focus=0.5):
    """Scale-invariant log (SILog) depth loss over valid (depth_gt > 0) pixels.

    Returns 0 when no pixel is valid, so empty ground truth cannot NaN the
    training step.
    """
    mask = (depth_gt > 0).detach()
    if mask.sum() == 0:
        return torch.tensor(0.0).to(depth_est)
    d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
    # variance_focus in (0, 1] trades scale error vs. relative error.
    loss = torch.sqrt(torch.pow(d, 2).mean() - variance_focus * torch.pow(d.mean(), 2))
    return loss


def make_grid(images, pil_images):
    """Compose predicted/GT image pairs into a captioned grid.

    inputs:
        images: list of PIL images (predictions)
        pil_images: list of PIL images (ground truth), resized to match
    returns:
        a single PIL.Image laying out [Predicted, GT, Predicted, GT, ...]
        rows of up to 16 tiles, each with a red caption under it
    """
    # Interleave prediction/GT pairs; assuming each image is the same size.
    new_images = []
    new_captions = []
    for image, pil_image in zip(images, pil_images):
        new_images.append(image)
        pil_image = pil_image.resize((image.size[0], image.size[1]))
        new_images.append(pil_image)
        new_captions.append("Predicted")
        new_captions.append("GT")
    images = new_images
    captions = new_captions

    width, height = images[0].size
    font_size = 14
    caption_height = font_size + 10

    # Calculate the size of the final image.
    images_per_row = min(len(images), 16)
    # BUGFIX: proper ceiling division. The previous (len + 1) // per_row
    # under-counted rows (e.g. 20 images -> 1 row instead of 2), so tiles
    # past the first row were pasted outside the canvas and lost.
    row_count = (len(images) + images_per_row - 1) // images_per_row
    total_width = width * images_per_row
    total_height = (height + caption_height) * row_count

    # Create a new blank image and draw each tile plus its caption.
    new_image = Image.new("RGB", (total_width, total_height), "white")
    draw = ImageDraw.Draw(new_image)
    for i, (image, caption) in enumerate(zip(images, captions)):
        row = i // images_per_row
        col = i % images_per_row
        x_offset = col * width
        y_offset = row * (height + caption_height)
        new_image.paste(image, (x_offset, y_offset))
        text_position = (x_offset + 10, y_offset + height)
        draw.text(text_position, caption, fill="red", font_size=font_size)
    return new_image


def visualize_masks(anns, rgb_image):
    """Blend SAM-style annotation masks over an RGB image.

    inputs:
        anns: list of dicts with "segmentation" (bool HxW mask) and "area";
            drawn largest-first so small masks stay on top
        rgb_image: PIL image (or array-convertible) to overlay on
    returns:
        PIL.Image of the blended overlay (also drawn on the current
        matplotlib axes via plt.gca() as a side effect)
    """
    if len(anns) == 0:
        return rgb_image
    sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img_array = np.array(rgb_image)
    masked_image = np.ones(img_array.shape)
    for ann in sorted_anns:
        m = ann["segmentation"]
        color_mask = np.random.random(3)
        masked_image[m] = (color_mask * 255).astype(np.uint8)
    # 35% original image + 65% color overlay, then back to uint8 for PIL.
    img_array = img_array * 0.35 + masked_image * 0.65
    img_array = img_array.astype(np.uint8)
    ax.imshow(img_array)
    overlayed_img = Image.fromarray(img_array)
    return overlayed_img