import math
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule.
import PIL.ImageDraw  # Nor ImageDraw, used by draw_handle_target_points.
import torch


def tensor_to_PIL(img: torch.Tensor) -> PIL.Image.Image:
    """Convert the first image of a batched tensor to a PIL RGB image.

    Pixel values are mapped with ``x * 127.5 + 128`` and clamped to
    [0, 255], so inputs in [-1, 1] cover the full 8-bit range. The extra
    0.5 (128 instead of 127.5) makes the truncating uint8 cast round to
    the nearest integer.

    Args:
        img (torch.Tensor): The tensor image of shape
            [batch_size, num_channels, height, width]. Only ``img[0]`` is
            converted, and it must have 3 channels for "RGB" mode.

    Returns:
        A PIL Image object in "RGB" mode.
    """
    img = (
        (img.detach().permute(0, 2, 3, 1) * 127.5 + 128)
        .clamp(0, 255)
        .to(torch.uint8)
    )
    return PIL.Image.fromarray(img[0].cpu().numpy(), "RGB")


def get_ellipse_coords(
    point: Tuple[int, int], radius: int = 5
) -> Tuple[int, int, int, int]:
    """Return the bounding box of a circle centered at the given point.

    Args:
        point (Tuple[int, int]): The center point of the circle, as (x, y).
        radius (int): The radius of the circle.

    Returns:
        A tuple (x_min, y_min, x_max, y_max), the box format accepted by
        ``PIL.ImageDraw.ImageDraw.ellipse``.
    """
    x, y = point[0], point[1]
    return (x - radius, y - radius, x + radius, y + radius)


def _draw_arrow(
    draw: "PIL.ImageDraw.ImageDraw",
    start: Tuple[float, float],
    end: Tuple[float, float],
    head_length: float = 10.0,
    fill: str = "green",
    width: int = 2,
) -> None:
    """Draw a line from ``start`` to ``end`` with a filled triangular head at ``end``."""
    angle = math.atan2(end[1] - start[1], end[0] - start[0])

    # End the shaft at the base of the arrow head rather than at the tip.
    shaft_end = (
        end[0] - head_length * math.cos(angle),
        end[1] - head_length * math.sin(angle),
    )
    draw.line([tuple(start), shaft_end], fill=fill, width=width)

    # The two back corners of the head lie head_length behind the tip,
    # rotated +/- 30 degrees (pi / 6) from the shaft direction.
    corner1 = (
        end[0] - head_length * math.cos(angle - math.pi / 6),
        end[1] - head_length * math.sin(angle - math.pi / 6),
    )
    corner2 = (
        end[0] - head_length * math.cos(angle + math.pi / 6),
        end[1] - head_length * math.sin(angle + math.pi / 6),
    )
    draw.polygon([tuple(end), corner1, corner2], fill=fill)


def draw_handle_target_points(
    img: PIL.Image.Image,
    handle_points: List[Tuple[int, int]],
    target_points: List[Tuple[int, int]],
    radius: int = 5,
) -> None:
    """Draw handle and target points with an arrow pointing to each target.

    Handles are drawn as red dots and targets as blue dots. Each
    handle/target pair is connected by a green arrow. The image is drawn
    on in place. ``target_points`` is not modified.

    Args:
        img (PIL.Image.Image): The image to draw on.
        handle_points (List[Tuple[int, int]]): A list of handle [x, y] points.
        target_points (List[Tuple[int, int]]): A list of target [x, y] points.
            It may contain exactly one fewer entry than ``handle_points``; the
            last handle is then drawn without a target or arrow. Otherwise,
            pairs are formed with ``zip`` and any extras are ignored.
        radius (int): The radius of the handle and target points.
    """
    # Pad a local copy rather than appending to the caller's list.
    targets: List[Optional[Tuple[int, int]]] = list(target_points)
    if len(handle_points) == len(targets) + 1:
        targets.append(None)

    draw = PIL.ImageDraw.Draw(img)
    arrow_head_length = 10.0
    for handle_point, target_point in zip(handle_points, targets):
        draw.ellipse(get_ellipse_coords(handle_point, radius), fill="red")
        # `is not None` (not truthiness) so array-like points also work.
        if target_point is None:
            continue
        draw.ellipse(get_ellipse_coords(target_point, radius), fill="blue")
        _draw_arrow(draw, handle_point, target_point, arrow_head_length)


def create_circular_mask(
    h: int,
    w: int,
    center: Optional[Tuple[int, int]] = None,
    radius: Optional[int] = None,
) -> torch.Tensor:
    """Create a circular mask tensor.

    Args:
        h (int): The height of the mask tensor.
        w (int): The width of the mask tensor.
        center (Optional[Tuple[int, int]]): The center of the circle as a
            tuple (y, x). If None, the middle of the image is used.
        radius (Optional[int]): The radius of the circle. If None, the
            smallest distance between the center and the image borders is used.

    Returns:
        A boolean tensor of shape [h, w] that is True where the Euclidean
        distance to ``center`` is at most ``radius``.
    """
    if center is None:
        center = (int(h / 2), int(w / 2))
    if radius is None:
        radius = min(center[0], center[1], h - center[0], w - center[1])

    Y, X = np.ogrid[:h, :w]  # broadcastable row / column index grids
    dist_from_center = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2)
    mask = dist_from_center <= radius
    return torch.from_numpy(mask).bool()


def create_square_mask(
    height: int,
    width: int,
    center: Union[List[int], Tuple[int, int]],
    radius: int,
) -> torch.Tensor:
    """Create a square mask tensor.

    The square covers rows ``center[0] - radius`` to ``center[0] + radius``
    and columns ``center[1] - radius`` to ``center[1] + radius``, inclusive.
    Its side is therefore ``2 * radius + 1`` pixels.

    Args:
        height (int): The height of the mask.
        width (int): The width of the mask.
        center (list or tuple): The center of the square as two integers,
            in [y, x] order.
        radius (int): The half-side of the square, excluding the center pixel.

    Returns:
        torch.Tensor: A boolean tensor of shape (height, width), True inside
        the square.

    Raises:
        ValueError: If ``center`` is not a list/tuple of two values,
            ``radius`` is not a positive integer, or the square would
            extend past the mask borders.
    """
    if not isinstance(center, (list, tuple)) or len(center) != 2:
        raise ValueError("center must be a list or tuple of two integers")
    if not isinstance(radius, int) or radius <= 0:
        raise ValueError("radius must be a positive integer")

    cy, cx = center
    if not (radius <= cy < height - radius and radius <= cx < width - radius):
        raise ValueError("center and radius must be within the bounds of the mask")

    y, x = int(cy), int(cx)
    mask = torch.zeros((height, width), dtype=torch.bool)
    mask[y - radius : y + radius + 1, x - radius : x + radius + 1] = True
    return mask