# DragGAN_Streamlit / utils.py
# Evan Davis
# Add files from GH
# da180b6
import math
from typing import List, Optional, Tuple

import numpy as np
import PIL
import PIL.Image
import PIL.ImageDraw
import torch
def tensor_to_PIL(img: torch.Tensor) -> PIL.Image.Image:
    """
    Converts a tensor image to a PIL Image.

    Each value v is mapped to ``v * 127.5 + 128``, clamped to [0, 255] and cast to
    uint8. The input is therefore presumably expected to lie roughly in [-1, 1]
    (the usual StyleGAN output range) -- confirm against the caller.

    Args:
        img (torch.Tensor): Either a batch of shape
            [batch_size, num_channels, height, width], of which only the first
            image is converted, or a single image of shape
            [num_channels, height, width]. num_channels is expected to be 3,
            since the result is built in "RGB" mode.

    Returns:
        A PIL Image object.
    """
    if img.dim() == 4:
        # Only the first image of the batch is returned, so convert just that one
        # instead of scaling/clamping the whole batch.
        img = img[0]
    # [C, H, W] -> [H, W, C], the channel-last layout PIL expects.
    pixels = (img.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(pixels.cpu().numpy(), "RGB")
def get_ellipse_coords(
    point: Tuple[int, int], radius: int = 5
) -> Tuple[int, int, int, int]:
    """
    Compute the bounding box of a circle of the given radius around a point.

    Args:
        point (Tuple[int, int]): Center of the circle as (x, y).
        radius (int): Distance from the center to each side of the box.

    Returns:
        A 4-tuple (x_min, y_min, x_max, y_max), the box form that
        ``ImageDraw.ellipse`` takes.
    """
    cx, cy = point[0], point[1]
    x_min, x_max = cx - radius, cx + radius
    y_min, y_max = cy - radius, cy + radius
    return (x_min, y_min, x_max, y_max)
def draw_handle_target_points(
    img: PIL.Image.Image,
    handle_points: List[Tuple[int, int]],
    target_points: List[Optional[Tuple[int, int]]],
    radius: int = 5,
) -> None:
    """
    Draws handle and target points with an arrow pointing from each handle to its target.

    Handles and targets are paired by index. Every handle is drawn as a red dot.
    A handle that has a matching (non-None) target also gets a blue target dot and
    a green arrow from the handle to the target. Targets without a matching handle
    are ignored. ``img`` is modified in place; neither point list is modified.

    Args:
        img (PIL.Image.Image): The image to draw on (modified in place).
        handle_points (List[Tuple[int, int]]): A list of handle [x,y] points.
        target_points (List[Optional[Tuple[int, int]]]): A list of target [x,y]
            points. It may be shorter than ``handle_points`` or contain None for
            handles that have no target yet.
        radius (int): The radius of the handle and target points.
    """
    arrow_head_length = 10.0
    # Angle between each arrowhead edge and the shaft (30 degrees).
    arrow_head_half_angle = math.pi / 6
    # The arrowhead's base lies this far back from its tip along the shaft.
    shaft_inset = arrow_head_length * math.cos(arrow_head_half_angle)

    draw = PIL.ImageDraw.Draw(img)
    for index, handle_point in enumerate(handle_points):
        # Draw the handle point
        draw.ellipse(get_ellipse_coords(handle_point, radius), fill="red")

        # Pair by index while only reading the caller's list: a handle whose
        # target is missing (or None) is drawn on its own.
        target_point = target_points[index] if index < len(target_points) else None
        if target_point is None:
            continue

        # Draw the target point
        draw.ellipse(get_ellipse_coords(target_point, radius), fill="blue")

        # Direction vector of the arrow
        dx = target_point[0] - handle_point[0]
        dy = target_point[1] - handle_point[1]
        if dx == 0 and dy == 0:
            # Handle and target coincide: there is no direction to draw an arrow in.
            continue
        angle = math.atan2(dy, dx)

        # End the shaft exactly at the base of the arrowhead so the two meet
        # without a gap and the shaft does not run through the arrow tip.
        shortened_target_point = (
            target_point[0] - shaft_inset * math.cos(angle),
            target_point[1] - shaft_inset * math.sin(angle),
        )
        draw.line([tuple(handle_point), shortened_target_point], fill="green", width=2)

        # The two back corners of the arrowhead, one on each side of the shaft
        arrow_point1 = (
            target_point[0] - arrow_head_length * math.cos(angle - arrow_head_half_angle),
            target_point[1] - arrow_head_length * math.sin(angle - arrow_head_half_angle),
        )
        arrow_point2 = (
            target_point[0] - arrow_head_length * math.cos(angle + arrow_head_half_angle),
            target_point[1] - arrow_head_length * math.sin(angle + arrow_head_half_angle),
        )
        draw.polygon([tuple(target_point), arrow_point1, arrow_point2], fill="green")
def create_circular_mask(
    h: int,
    w: int,
    center: Optional[Tuple[int, int]] = None,
    radius: Optional[int] = None,
) -> torch.Tensor:
    """
    Build a boolean mask that is True on a filled disk.

    Args:
        h (int): Number of rows in the mask.
        w (int): Number of columns in the mask.
        center (Optional[Tuple[int, int]]): Disk center as (y, x). When omitted,
            (int(h / 2), int(w / 2)) is used.
        radius (Optional[int]): Disk radius. When omitted, the distance from the
            center to the nearest image border is used.

    Returns:
        A boolean tensor of shape [h, w]; a pixel is True when its Euclidean
        distance to the center is at most ``radius``.
    """
    if center is None:
        center = (int(h / 2), int(w / 2))
    cy, cx = center[0], center[1]

    if radius is None:
        radius = min(cy, cx, h - cy, w - cx)

    # Open grids: a column vector of row indices and a row vector of column indices.
    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((rows - cy) ** 2 + (cols - cx) ** 2)
    inside = distance <= radius
    return torch.from_numpy(inside).bool()
def create_square_mask(
    height: int, width: int, center: list, radius: int
) -> torch.Tensor:
    """Create a square boolean mask.

    The square covers rows ``center[0] - radius`` to ``center[0] + radius`` and
    columns ``center[1] - radius`` to ``center[1] + radius``, both inclusive, so
    its side length is ``2 * radius + 1``.

    Args:
        height (int): The height of the mask.
        width (int): The width of the mask.
        center (list | tuple): The center of the square as two integers, order
            [y, x]. Tuples are accepted as well as lists, matching
            ``create_circular_mask``.
        radius (int): Half the side length of the square, not counting the
            center pixel. Must be a positive integer.

    Returns:
        torch.Tensor: A boolean tensor of shape (height, width) that is True
        inside the square and False elsewhere.

    Raises:
        ValueError: If ``center`` is not a two-element list/tuple, ``radius`` is
            not a positive integer, or the square does not fit entirely inside
            the mask.
    """
    if not isinstance(center, (list, tuple)) or len(center) != 2:
        raise ValueError("center must be a list or tuple of two integers")
    if not isinstance(radius, int) or radius <= 0:
        raise ValueError("radius must be a positive integer")
    # The whole square, including its far edge, must lie inside the mask.
    if (
        center[0] < radius
        or center[0] >= height - radius
        or center[1] < radius
        or center[1] >= width - radius
    ):
        raise ValueError("center and radius must be within the bounds of the mask")

    y, x = int(center[0]), int(center[1])
    mask = torch.zeros((height, width), dtype=torch.bool)
    # Slices are end-exclusive, hence the +1 to include the far row/column.
    mask[y - radius : y + radius + 1, x - radius : x + radius + 1] = True
    return mask