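"""Preprocessing helpers for building batched Idefics2 model inputs.

Images are resized (longest side capped at 980 px, shortest side clamped to at
least 378 px), rescaled, and normalized with the model's image processor;
variable numbers of variable-sized images per sample are then zero-padded into
one batch tensor alongside a matching pixel attention mask.
"""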
import os
import torch

from transformers import AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from PIL import Image


PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2",
    token=os.environ["HF_AUTH_TOKEN"],
)


def convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)

    # Downscale so that the longest side is at most 980 px, preserving the
    # aspect ratio.
    height, width = x.shape[:2]
    aspect_ratio = width / height
    if width >= height and width > 980:
        width = 980
        height = int(width / aspect_ratio)
    elif height > width and height > 980:
        height = 980
        width = int(height * aspect_ratio)
    # Enforce a minimum side length of 378 px. Note that this clamp can
    # distort the aspect ratio of small or very elongated images.
    width = max(width, 378)
    height = max(height, 378)

    x = resize(x, (height, width), resample=PILImageResampling.BILINEAR)
    # Rescale to [0, 1], normalize with the model's mean/std, and convert to a
    # channels-first torch tensor.
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x


def create_model_inputs(
        input_texts: List[str],
        image_lists: List[List[Image.Image]],
    ):
    """
    Tokenize `input_texts` and collate the per-sample image lists into padded
    batch tensors. All this logic will eventually be handled inside the model
    processor.
    """
    inputs = PROCESSOR.tokenizer(
        input_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    )

    output_images = [
        [PROCESSOR.image_processor(img, transform=custom_transform) for img in im_list]
        for im_list in image_lists
    ]
    total_batch_size = len(output_images)
    max_num_images = max(len(img_l) for img_l in output_images)
    if max_num_images > 0:
        # Samples may contain different numbers of differently sized images:
        # zero-pad every image to the batch-wide maximum height/width, pad the
        # image count, and record which pixels are real in a boolean mask.
        max_height = max(i.size(2) for img_l in output_images for i in img_l)
        max_width = max(i.size(3) for img_l in output_images for i in img_l)
        padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
        padded_pixel_attention_masks = torch.zeros(
            total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
        )
        for batch_idx, img_l in enumerate(output_images):
            for img_idx, img in enumerate(img_l):
                im_height, im_width = img.size()[2:]
                padded_image_tensor[batch_idx, img_idx, :, :im_height, :im_width] = img
                padded_pixel_attention_masks[batch_idx, img_idx, :im_height, :im_width] = True

        inputs["pixel_values"] = padded_image_tensor
        inputs["pixel_attention_mask"] = padded_pixel_attention_masks

    return inputs
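

# Minimal usage sketch (illustrative, not part of the original script): the
# image path and prompt string below are placeholder assumptions, and the
# exact image-token format expected by the tokenizer may differ per checkpoint.
if __name__ == "__main__":
    example_images = [[Image.open("example.jpg")]]   # hypothetical local file
    example_texts = ["<image>Describe this image."]  # assumed prompt format
    batch = create_model_inputs(example_texts, example_images)
    # Expect input_ids / attention_mask plus pixel_values of shape
    # (batch, num_images, 3, H, W) and a matching pixel_attention_mask.
    print({k: tuple(v.shape) for k, v in batch.items()})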