File size: 4,656 Bytes
8655a4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Usage:
python3 -m unittest tests.test_image_utils
"""

import base64
from io import BytesIO
import os
import unittest

import numpy as np
from PIL import Image

from fastchat.utils import (
    resize_image_and_return_image_in_bytes,
    image_moderation_filter,
)
from fastchat.conversation import get_conv_template


def check_byte_size_in_mb(image_base64_str):
    """Return the length of ``image_base64_str`` in megabytes (1 MB = 1024 * 1024).

    The value is computed purely from ``len()``, so any sized object works;
    the tests in this module pass both raw ``bytes`` and base64 payloads.
    """
    length = len(image_base64_str)
    length_in_kb = length / 1024
    return length_in_kb / 1024


def generate_random_image(target_size_mb, image_format="PNG"):
    """Build a square random-noise RGB image whose encoded size meets a target.

    The image is encoded into an in-memory buffer to measure its size, so no
    temporary file is written to (or left behind in) the working directory.

    Args:
        target_size_mb: Minimum encoded size in megabytes (1 MB = 1024 * 1024
            bytes).
        image_format: Format name passed to ``Image.save`` when measuring the
            encoded size. Defaults to ``"PNG"``.

    Returns:
        The ``PIL.Image.Image`` whose encoding in ``image_format`` is at least
        ``target_size_mb`` megabytes.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    # Three bytes per RGB pixel: start from the side length whose raw pixel
    # data is roughly the target size. Clamp to 1 so that very small targets
    # never produce an empty (0 x 0) image.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))

    while True:
        pixel_data = np.random.randint(
            0, 256, (dimension, dimension, 3), dtype=np.uint8
        )
        img = Image.fromarray(pixel_data)

        # Measure the encoded size in memory instead of via a file on disk.
        encoded = BytesIO()
        img.save(encoded, format=image_format)
        if len(encoded.getvalue()) >= target_size_bytes:
            return img

        # Still too small once encoded: grow the image and try again.
        dimension += 1


class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image smaller than the size cap must keep its encoded size."""

    def test_dont_resize_if_less_than_max(self):
        size_cap_mb = 5
        source_img = generate_random_image(0.1)  # well below the cap

        # Encode as PNG to get the baseline size.
        baseline_buffer = BytesIO()
        source_img.save(baseline_buffer, format="PNG")
        size_before = check_byte_size_in_mb(baseline_buffer.getvalue())

        # Run the image through the resize helper with the cap applied.
        result_buffer = resize_image_and_return_image_in_bytes(
            source_img, max_image_size_mb=size_cap_mb
        )
        size_after = check_byte_size_in_mb(result_buffer.getvalue())

        self.assertEqual(size_before, size_after)


class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """Send an oversized image through the moderation filter."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # Initial image size which we know is greater than what the endpoint
        # can take.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)

        nsfw_flag, csam_flag = image_moderation_filter(img)

        # Check BOTH flags: a random-noise image should trip neither one.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)


class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With ``max_image_size_mb=None`` the encoded size must not change."""

    def test_dont_resize_if_max_image_size_is_none(self):
        source_img = generate_random_image(0.2)

        # Encode as PNG to get the baseline size.
        baseline_buffer = BytesIO()
        source_img.save(baseline_buffer, format="PNG")
        size_before = check_byte_size_in_mb(baseline_buffer.getvalue())

        # No size cap is given to the resize helper.
        result_buffer = resize_image_and_return_image_in_bytes(
            source_img, max_image_size_mb=None
        )
        size_after = check_byte_size_in_mb(result_buffer.getvalue())

        self.assertEqual(size_before, size_after)


class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template must base64-encode a small image without resizing."""

    def test(self):
        conv = get_conv_template("chatgpt")
        source_img = generate_random_image(0.2)

        # Encode as PNG to get the baseline size.
        baseline_buffer = BytesIO()
        source_img.save(baseline_buffer, format="PNG")
        size_before = check_byte_size_in_mb(baseline_buffer.getvalue())

        # Decode the template's base64 output and compare raw byte sizes.
        encoded_img = conv.convert_image_to_base64(source_img)
        decoded_bytes = base64.b64decode(encoded_img)
        size_after = check_byte_size_in_mb(decoded_bytes)

        self.assertEqual(size_before, size_after)


class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template must shrink a large image below its size limits."""

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        source_img = generate_random_image(5)

        # Encode as PNG to get the baseline size.
        baseline_buffer = BytesIO()
        source_img.save(baseline_buffer, format="PNG")
        size_before = check_byte_size_in_mb(baseline_buffer.getvalue())

        # Measure both the base64 payload and the decoded image bytes.
        encoded_img = conv.convert_image_to_base64(source_img)
        encoded_size_mb = check_byte_size_in_mb(encoded_img)
        decoded_size_mb = check_byte_size_in_mb(base64.b64decode(encoded_img))

        # The image must have shrunk and must fit the template's own limit.
        self.assertLess(decoded_size_mb, size_before)
        self.assertLessEqual(decoded_size_mb, conv.max_image_size_mb)
        # The base64 payload itself must also stay within 5 MB.
        self.assertLessEqual(encoded_size_mb, 5)