File size: 1,449 Bytes
23db031
 
 
 
7c11918
23db031
7c11918
 
23db031
7c11918
 
 
 
23db031
 
 
7c11918
23db031
 
 
 
 
 
 
7c11918
23db031
7c11918
23db031
7c11918
 
 
 
 
 
 
 
23db031
2007b83
7c11918
23db031
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import List, Union

import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image


def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: List[List[Union[int, float]]],
    save_path: str = "visualize.jpg",
) -> Image.Image:
    """
    Draw bounding boxes on an image and save the annotated result to disk.

    Args:
    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes. A tensor with
      more than 3 dimensions is treated as a batch and only its first image is drawn.
    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
      where coordinates are normalized [0, 1]. When ``img`` is a batched tensor, ``bboxes`` is
      expected to be batched too, and only ``bboxes[0]`` is used.
    - save_path (str): File the annotated image is written to. Defaults to "visualize.jpg".

    Returns:
    - PIL Image: The annotated RGB image. This is a new image; a PIL input is not modified.
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.info("Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    # convert() always returns a new image, so we never draw onto the caller's PIL image.
    # Normalizing to RGB also keeps the red/blue annotations in colour on grayscale inputs.
    # It also lets RGBA/palette images be saved as JPEG, which Pillow refuses for RGBA.
    img = img.convert("RGB")

    draw = ImageDraw.Draw(img)
    width, height = img.size
    font = ImageFont.load_default(30)

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max = bbox
        # float() turns 0-d tensor elements (from tensor bboxes) into plain numbers for PIL,
        # then normalized coordinates are scaled to pixel units.
        x_min, x_max = float(x_min) * width, float(x_max) * width
        y_min, y_max = float(y_min) * height, float(y_max) * height
        draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="red", width=3)
        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")

    img.save(save_path)  # Save the image with annotations
    logger.info(f"Saved visualize image at {save_path}")
    return img