import math
import random
import warnings
from itertools import cycle
from typing import List, Optional, Tuple, Callable

from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
    additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
    absolute_bbox, rescale_annotations
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor


class ObjectsCenterPointsConditionalBuilder:
    def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
                 use_group_parameter: bool, use_additional_parameters: bool):
        self.no_object_classes = no_object_classes
        self.no_max_objects = no_max_objects
        self.no_tokens = no_tokens
        self.encode_crop = encode_crop
        self.no_sections = int(math.sqrt(self.no_tokens))
        self.use_group_parameter = use_group_parameter
        self.use_additional_parameters = use_additional_parameters

    @property
    def none(self) -> int:
        """Token id used for padding ('no object'): the last id in the vocabulary."""
        return self.no_tokens - 1

    @property
    def object_descriptor_length(self) -> int:
        """Number of tokens per object: one class/representation token, one center-point token."""
        return 2

    @property
    def embedding_dim(self) -> int:
        """Total length of the conditioning sequence, including the optional crop token pair."""
        extra_length = 2 if self.encode_crop else 0
        return self.no_max_objects * self.object_descriptor_length + extra_length

    def tokenize_coordinates(self, x: float, y: float) -> int:
        """
        Express 2D coordinates as a single token index.
        Example: assume self.no_tokens = 16, then no_sections = 4:
        0  0  0  0
        0  0  #  0
        0  0  0  0
        0  0  0  x
        Then the # position corresponds to token 6, the x position to token 15.
        @param x: float in [0, 1]
        @param y: float in [0, 1]
        @return: discrete tokenized coordinate
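
        Doctest, assuming no_tokens=16 as in the grid above:
        >>> builder = ObjectsCenterPointsConditionalBuilder(1, 1, 16, False, False, False)
        >>> builder.tokenize_coordinates(2 / 3, 1 / 3)
        6
        >>> builder.tokenize_coordinates(1.0, 1.0)
        15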
        """
        x_discrete = int(round(x * (self.no_sections - 1)))
        y_discrete = int(round(y * (self.no_sections - 1)))
        return y_discrete * self.no_sections + x_discrete

    def coordinates_from_token(self, token: int) -> Tuple[float, float]:
        """Inverse of `tokenize_coordinates`: recover relative (x, y) from a token id."""
        x = token % self.no_sections
        y = token // self.no_sections
        return x / (self.no_sections - 1), y / (self.no_sections - 1)

    def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
        x0, y0 = self.coordinates_from_token(token1)
        x1, y1 = self.coordinates_from_token(token2)
        return x0, y0, x1 - x0, y1 - y0

    def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
        return self.tokenize_coordinates(bbox[0], bbox[1]), \
               self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])

    def inverse_build(self, conditional: LongTensor) \
            -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
        """Decode a conditioning sequence into (representation, (x, y)) pairs plus the optional crop bbox."""
        assert conditional.shape[0] == self.embedding_dim
        conditional_list = conditional.tolist()
        crop_coordinates = None
        if self.encode_crop:
            crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
            conditional_list = conditional_list[:-2]
        table_of_content = grouper(conditional_list, self.object_descriptor_length)
        return [
            (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
            for object_tuple in table_of_content if object_tuple[0] != self.none
        ], crop_coordinates

    def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
             line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
        plot = pil_image.new('RGB', figure_size, WHITE)
        draw = pil_img_draw.Draw(plot)
        circle_size = get_circle_size(figure_size)
        # NOTE: assumes the Lato font is installed at this path; adjust for your system.
        font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
                                  size=get_plot_font_size(font_size, figure_size))
        width, height = plot.size
        description, crop_coordinates = self.inverse_build(conditional)
        for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
            x_abs, y_abs = x * width, y * height
            ann = self.representation_to_annotation(representation)
            label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
            ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
            draw.ellipse(ellipse_bbox, fill=color, width=0)
            draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
        if crop_coordinates is not None:
            draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
        return convert_pil_to_tensor(plot) / 127.5 - 1.  # rescale from [0, 255] to [-1, 1]

    def object_representation(self, annotation: Annotation) -> int:
        """Pack the boolean annotation flags into bit positions above the category id."""
        modifier = 0
        if self.use_group_parameter:
            modifier |= 1 * (annotation.is_group_of is True)
        if self.use_additional_parameters:
            modifier |= 2 * (annotation.is_occluded is True)
            modifier |= 4 * (annotation.is_depiction is True)
            modifier |= 8 * (annotation.is_inside is True)
        return annotation.category_no + self.no_object_classes * modifier

    def representation_to_annotation(self, representation: int) -> Annotation:
        """Inverse of `object_representation`: unpack the category id and flag bits."""
        category_no = representation % self.no_object_classes
        modifier = representation // self.no_object_classes
        # noinspection PyTypeChecker
        return Annotation(
            area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
            category_no=category_no,
            is_group_of=bool((modifier & 1) * self.use_group_parameter),
            is_occluded=bool((modifier & 2) * self.use_additional_parameters),
            is_depiction=bool((modifier & 4) * self.use_additional_parameters),
            is_inside=bool((modifier & 8) * self.use_additional_parameters)
        )

    def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
        return list(self.token_pair_from_bbox(crop_coordinates))

    def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
        """Map each annotation to a (representation, center-token) pair, padded to `no_max_objects`."""
        object_tuples = [
            (self.object_representation(a),
             self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
            for a in annotations
        ]
        empty_tuple = (self.none, self.none)
        object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
        return object_tuples

    def build(self, annotations: List[Annotation], crop_coordinates: Optional[BoundingBox] = None,
              horizontal_flip: bool = False) -> LongTensor:
        """Encode a list of annotations (and an optional crop) into a conditioning token sequence."""
        if len(annotations) == 0:
            warnings.warn('Did not receive any annotations.')
        if len(annotations) > self.no_max_objects:
            warnings.warn('Received more annotations than allowed.')
            annotations = annotations[:self.no_max_objects]

        if not crop_coordinates:
            crop_coordinates = FULL_CROP

        random.shuffle(annotations)  # note: shuffles the caller's list in place unless it was truncated above
        annotations = filter_annotations(annotations, crop_coordinates)
        if self.encode_crop:
            annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
            if horizontal_flip:
                crop_coordinates = horizontally_flip_bbox(crop_coordinates)
            extra = self._crop_encoder(crop_coordinates)
        else:
            annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
            extra = []

        object_tuples = self._make_object_descriptors(annotations)
        flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
        assert len(flattened) == self.embedding_dim
        assert all(0 <= value < self.no_tokens for value in flattened)
        return LongTensor(flattened)
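

if __name__ == '__main__':
    # Minimal usage sketch: encode a single annotation, then decode it again.
    # This assumes Annotation accepts the keyword set used in
    # `representation_to_annotation` above; the concrete values here are
    # illustrative only.
    builder = ObjectsCenterPointsConditionalBuilder(
        no_object_classes=10, no_max_objects=3, no_tokens=1024,
        encode_crop=False, use_group_parameter=False, use_additional_parameters=False
    )
    annotation = Annotation(
        area=None, image_id=None, category_id=None, id=None, source=None, confidence=None,
        bbox=(0.25, 0.25, 0.5, 0.5),  # relative (x, y, w, h); center at (0.5, 0.5)
        category_no=3,
        is_group_of=False, is_occluded=False, is_depiction=False, is_inside=False
    )
    conditional = builder.build([annotation])
    # Three object slots of two tokens each; unused slots hold the `none` token (1023).
    print(conditional.tolist())
    # Recovers [(3, (0.516..., 0.516...))] and no crop bbox.
    print(builder.inverse_build(conditional))
    # Round trip of the bit-packed object representation.
    representation = builder.object_representation(annotation)
    assert builder.representation_to_annotation(representation).category_no == annotation.category_no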