Spaces:
Runtime error
Runtime error
File size: 4,656 Bytes
8655a4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
"""
Usage:
python3 -m unittest tests.test_image_utils
"""
import base64
from io import BytesIO
import os
import unittest
import numpy as np
from PIL import Image
from fastchat.utils import (
resize_image_and_return_image_in_bytes,
image_moderation_filter,
)
from fastchat.conversation import get_conv_template
def check_byte_size_in_mb(image_base64_str):
    """Return the length of ``image_base64_str`` expressed in mebibytes (MiB).

    Accepts anything that supports ``len`` -- e.g. the ``bytes`` from
    ``BytesIO.getvalue()`` or a base64-encoded ``str`` -- and counts each
    element as one byte.
    """
    bytes_per_mib = 1024 * 1024
    return len(image_base64_str) / bytes_per_mib
def generate_random_image(target_size_mb, image_format="PNG"):
    """Create a random RGB noise image whose encoded size is >= ``target_size_mb`` MiB.

    The encoded size is measured in memory; unlike a temp-file approach, no
    file is written to the current working directory, so tests do not leave
    stray files behind or clobber each other's output.

    Args:
        target_size_mb: Minimum encoded size, in MiB, the returned image must reach
            when saved with ``image_format``.
        image_format: Pillow format name used to measure the encoded size
            (e.g. ``"PNG"``).

    Returns:
        A ``PIL.Image.Image`` in RGB mode, square, filled with uniform random pixels.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    # Random noise barely compresses, so ~3 bytes/pixel is a reasonable first
    # estimate for a square RGB image. Clamp to 1 so a zero target still yields
    # an encodable image.
    dimension = max(1, int((max(target_size_bytes, 0) / 3) ** 0.5))

    def _random_rgb(side):
        # Fresh uniform noise of shape (side, side, 3), uint8.
        pixels = np.random.randint(0, 256, (side, side, 3), dtype=np.uint8)
        return Image.fromarray(pixels)

    def _encoded_size(image):
        # Encode into memory only to measure the resulting byte count.
        buffer = BytesIO()
        image.save(buffer, format=image_format)
        return len(buffer.getvalue())

    img = _random_rgb(dimension)
    # Grow one pixel per side until the encoded output reaches the target.
    while _encoded_size(img) < target_size_bytes:
        dimension += 1
        img = _random_rgb(dimension)
    return img
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image already under ``max_image_size_mb`` must keep its encoded size."""

    def test_dont_resize_if_less_than_max(self):
        max_image_size = 5  # size cap in MiB handed to the resizer
        initial_size_mb = 0.1  # well below the cap
        img = generate_random_image(initial_size_mb)

        # Encode as PNG to obtain the baseline size for comparison.
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())

        # NOTE(review): the return value is used via .getvalue(), so it is
        # presumably a BytesIO -- confirm against fastchat.utils.
        image_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=max_image_size
        )
        new_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        self.assertEqual(previous_image_size, new_image_size)
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """An oversized image must go through the moderation filter without being flagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # 6 MiB is chosen to exceed what the moderation endpoint can take, so
        # the filter presumably has to shrink the image before sending it.
        # NOTE(review): this likely requires network access to the endpoint.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)
        nsfw_flag, csam_flag = image_moderation_filter(img)
        # Random noise should trip neither classifier; check both flags.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With ``max_image_size_mb=None`` the resizer must not change the encoded size."""

    def test_dont_resize_if_max_image_size_is_none(self):
        initial_size_mb = 0.2  # size of the generated test image in MiB
        img = generate_random_image(initial_size_mb)

        # Encode as PNG to obtain the baseline size for comparison.
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())

        # NOTE(review): the return value is used via .getvalue(), so it is
        # presumably a BytesIO -- confirm against fastchat.utils.
        image_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=None
        )
        new_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        self.assertEqual(previous_image_size, new_image_size)
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template must not shrink a small image when base64-encoding it."""

    def test(self):
        conv = get_conv_template("chatgpt")
        initial_size_mb = 0.2  # size of the generated test image in MiB
        img = generate_random_image(initial_size_mb)

        # Encode as PNG to obtain the baseline size for comparison.
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())

        # convert_image_to_base64 returns base64 text; decode it to compare
        # raw byte counts rather than the ~4/3-larger encoded length.
        resized_img = conv.convert_image_to_base64(img)
        resized_img_bytes = base64.b64decode(resized_img)
        new_image_size = check_byte_size_in_mb(resized_img_bytes)
        self.assertEqual(previous_image_size, new_image_size)
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template must shrink a large image below its configured limit."""

    # Upper bound (MiB) asserted on the base64 text itself.
    # NOTE(review): presumably the per-image payload limit of Claude's API — confirm.
    MAX_BASE64_SIZE_MB = 5

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        initial_size_mb = 5  # size of the generated test image in MiB
        img = generate_random_image(initial_size_mb)

        # Encode as PNG to obtain the baseline size for comparison.
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())

        resized_img = conv.convert_image_to_base64(img)
        # Measure both the base64 text and the decoded raw bytes: the raw
        # bytes are checked against the template's limit, the text against
        # the fixed payload bound.
        new_base64_image_size = check_byte_size_in_mb(resized_img)
        new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(resized_img))

        self.assertLess(new_image_bytes_size, previous_image_size)
        self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb)
        self.assertLessEqual(new_base64_image_size, self.MAX_BASE64_SIZE_MB)
|