|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import numpy as np |
|
|
|
from transformers import is_torch_available, is_vision_available |
|
from transformers.processing_utils import _validate_images_text_input_order |
|
from transformers.testing_utils import require_torch, require_vision |
|
|
|
|
|
if is_vision_available(): |
|
import PIL |
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
|
|
@require_vision
class ProcessingUtilTester(unittest.TestCase):
    """Tests for `_validate_images_text_input_order`.

    Every case calls the function twice: once with `images`/`text` in the expected order and once
    with the two arguments swapped. In both calls the function must return `(images, text)` in the
    expected order. Inputs that cannot be disambiguated must raise `ValueError`.
    """

    def _assert_sequences_equal(self, actual, expected, equal):
        """Assert `actual` and `expected` have the same length and `equal(a, e)` holds element-wise.

        `equal` is an array comparison such as `np.array_equal` or `torch.equal`. Array-likes cannot
        be compared with `assertEqual` directly.
        """
        self.assertEqual(len(actual), len(expected))
        for actual_item, expected_item in zip(actual, expected):
            self.assertTrue(equal(actual_item, expected_item))

    def test_validate_images_text_input_order(self):
        # Single PIL image with a single string.
        images = PIL.Image.new("RGB", (224, 224))
        text = "text"

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # Swapped arguments must be put back in order.
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # Single numpy image with a list of strings.
        images = np.random.rand(224, 224, 3)
        text = ["text1", "text2"]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # List of PIL images with nested lists of strings.
        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
        text = [["text1", "text2, text3"], ["text3", "text4"]]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # List of numpy images with a list of strings: compare every image, not just the first.
        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
        text = ["text1", "text2"]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self._assert_sequences_equal(valid_images, images, np.array_equal)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self._assert_sequences_equal(valid_images, images, np.array_equal)
        self.assertEqual(valid_text, text)

        # Image URLs with a list of strings.
        images = ["https://url1", "https://url2"]
        text = ["text1", "text2"]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # Nested (ragged) lists of numpy images with a list of strings.
        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
        text = ["text1", "text2"]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(len(valid_images), len(images))
        for valid_group, group in zip(valid_images, images):
            self._assert_sequences_equal(valid_group, group, np.array_equal)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(len(valid_images), len(images))
        for valid_group, group in zip(valid_images, images):
            self._assert_sequences_equal(valid_group, group, np.array_equal)
        self.assertEqual(valid_text, text)

        # Nested (ragged) lists of PIL images with nested lists of strings.
        images = [
            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
            [PIL.Image.new("RGB", (224, 224))],
        ]
        text = [["text1", "text2, text3"], ["text3", "text4"]]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # No images: the returned images must be None and the text must come back in the text slot.
        # (Check the returned values, not the local inputs.)
        images = None
        text = "text"

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        # No text: the image must come back in the images slot and the returned text must be None.
        images = PIL.Image.new("RGB", (224, 224))
        text = None

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        # Two plain strings cannot be disambiguated, so the function must refuse them.
        images = "text"
        text = "text"
        with self.assertRaises(ValueError):
            _validate_images_text_input_order(images=images, text=text)

    @require_torch
    def test_validate_images_text_input_order_torch(self):
        # Single torch tensor image with a single string.
        images = torch.rand(224, 224, 3)
        text = "text"

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # List of torch tensor images with a list of strings: compare every image.
        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
        text = ["text1", "text2"]

        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self._assert_sequences_equal(valid_images, images, torch.equal)
        self.assertEqual(valid_text, text)

        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self._assert_sequences_equal(valid_images, images, torch.equal)
        self.assertEqual(valid_text, text)
|
|