LSP-DETR / image_processing.py
matejpekar's picture
Upload processor
3b3a0ee verified
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch import Tensor
from transformers import BaseImageProcessorFast
class LSPDetrImageProcessor(BaseImageProcessorFast):
    """Image processor for LSP-DETR.

    Declares the preprocessing defaults picked up by the base class and offers
    helpers that turn raw model outputs into polygons and instance masks.
    """

    # Per-channel normalization statistics (the commonly used ImageNet values).
    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)
    do_rescale = True
    do_normalize = True

    def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
        """Converts the raw output into polygons.

        Args:
            outputs: Raw model outputs. The keys "radial_distances",
                "absolute_points", "logits" and "embeddings" are read.

        Returns:
            A list with one dictionary per entry along the first (batch)
            dimension, each containing:
                - "polygons": A tensor of shape (N, num_radial_distances, 2) representing the polygons.
                - "labels": A tensor of shape (N,) representing the labels for each polygon.
                - "embeddings": The entries of `outputs["embeddings"]` belonging to the kept predictions.
        """
        logits = outputs["logits"]

        # expm1 is the inverse of log1p, so the distances are read as log1p-encoded.
        radial_distances = outputs["radial_distances"].expm1()
        num_rays = radial_distances.size(-1)

        # One angle per radial distance, evenly spaced over [0, 2*pi) with the
        # endpoint dropped so the first ray is not duplicated.
        t = torch.linspace(0, 1, num_rays + 1, device=radial_distances.device)[:-1]
        angles = 2 * torch.pi * t

        # NOTE: the first coordinate uses sin and the second cos; this order is
        # kept exactly as is (presumably the convention the model was trained with).
        directions = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        offsets = radial_distances.unsqueeze(-1) * directions
        polygons = outputs["absolute_points"].unsqueeze(-2) + offsets

        labels = logits.argmax(dim=-1)
        # The last class index is treated as "no object"; those predictions are dropped.
        keep = labels != logits.size(-1) - 1

        return [
            {
                "polygons": polygons[b, selected],
                "labels": labels[b, selected],
                "embeddings": outputs["embeddings"][b, selected],
            }
            for b, selected in enumerate(keep)
        ]

    def post_process_instance(
        self,
        results: list[dict[str, Tensor]],
        height: int,
        width: int,
    ) -> list[dict[str, Tensor]]:
        """Converts the output into actual instance segmentation predictions.

        Args:
            results: Results list obtained by `post_process`, to which "masks" results will be added.
            height: Height of the input image.
            width: Width of the input image.

        Returns:
            The same `results` list (modified in place); every dictionary gains a
            "masks" entry, a boolean tensor of shape (N, height, width) located on
            the same device as that dictionary's "polygons".
        """
        for result in results:
            polygons = result["polygons"]

            # Fetch all vertex coordinates in one go instead of once per polygon;
            # each row is the flat [c0, c1, c0, c1, ...] sequence PIL expects.
            flat_polygons = polygons.flatten(start_dim=1).tolist()

            # Rasterize on the host and move to the target device a single time.
            rasters = np.zeros((len(flat_polygons), height, width), dtype=bool)
            for j, xy in enumerate(flat_polygons):
                # PIL sizes are (width, height); mode "1" is a 1-bit canvas.
                img = Image.new("1", (width, height), 0)
                ImageDraw.Draw(img).polygon(xy=xy, outline=1, fill=1)
                rasters[j] = np.asarray(img)

            result["masks"] = torch.from_numpy(rasters).to(polygons.device)
        return results