# maomao517's picture
# Upload folder using huggingface_hub
# d10d213 verified
import json
import pytest
import torch
from PIL import Image
from lbm.data.mappers import (
KeyRenameMapper,
KeyRenameMapperConfig,
MapperWrapper,
RescaleMapper,
RescaleMapperConfig,
TorchvisionMapper,
TorchvisionMapperConfig,
)
class TestKeyRenameMapper:
    """Unconditional renaming: mapped keys are moved, unmapped keys pass through."""

    @pytest.fixture()
    def dummy_batch(self):
        # "label" has no entry in key_map; the test expects it to be left as-is.
        return dict(image=1, text=2, label="dummy_label")

    @pytest.fixture()
    def mapper(self):
        rename_config = KeyRenameMapperConfig(
            key_map={"image": "image_tensor", "text": "text_tensor"}
        )
        return KeyRenameMapper(rename_config)

    def test_mapper(self, mapper, dummy_batch):
        result = mapper(dummy_batch)

        expected_values = {
            "image_tensor": 1,
            "text_tensor": 2,
            "label": "dummy_label",
        }
        for key, value in expected_values.items():
            assert result[key] == value

        # The source names must no longer be present after renaming.
        for old_key in ("image", "text"):
            assert old_key not in result
class TestKeyRenameMapperWithCondition:
    """Conditional renaming driven by ``condition_key`` / ``condition_fn``.

    The test expects:
    * ``key_map`` to be applied when ``label == 1``;
    * ``else_key_map`` to be applied otherwise, when one is given;
    * the batch to keep its original keys when neither applies.
    """

    @pytest.fixture(params=[1, 2])
    def dummy_batch(self, request):
        label_value = request.param
        return {"image": 1, "text": 2, "label": label_value}

    @pytest.fixture(params=[{"image": "image_not_met", "text": "text_not_met"}, None])
    def else_key_map(self, request):
        return request.param

    @pytest.fixture()
    def mapper(self, else_key_map):
        def label_is_one(value):
            return value == 1

        config = KeyRenameMapperConfig(
            key_map={"image": "image_tensor", "text": "text_tensor"},
            condition_key="label",
            condition_fn=label_is_one,
            else_key_map=else_key_map,
        )
        return KeyRenameMapper(config)

    def test_mapper(self, mapper, dummy_batch, else_key_map):
        output_data = mapper(dummy_batch)
        condition_met = dummy_batch["label"] == 1

        if condition_met:
            # Primary key_map was used.
            expected = {"image_tensor": 1, "text_tensor": 2, "label": 1}
            absent = ("image", "text")
        elif else_key_map is not None:
            # Fallback else_key_map was used.
            expected = {"image_not_met": 1, "text_not_met": 2, "label": 2}
            absent = ("image", "text")
        else:
            # No map applies: keys must be untouched.
            expected = {"image": 1, "text": 2, "label": 2}
            absent = ("image_tensor", "text_tensor")

        for key, value in expected.items():
            assert output_data[key] == value
        for key in absent:
            assert key not in output_data
class TestMapperWrapper:
    """MapperWrapper chaining two KeyRenameMappers in sequence."""

    @pytest.fixture()
    def dummy_batch(self):
        return dict(image=1, text=2, label="dummy_label")

    @pytest.fixture()
    def mapper(self):
        # Stage 1 renames image/text -> *_tensor; stage 2 renames *_tensor -> *_array.
        to_tensor_names = KeyRenameMapper(
            KeyRenameMapperConfig(
                key_map={"image": "image_tensor", "text": "text_tensor"}
            )
        )
        to_array_names = KeyRenameMapper(
            KeyRenameMapperConfig(
                key_map={
                    "image_tensor": "image_array",
                    "text_tensor": "text_array",
                }
            )
        )
        return MapperWrapper(mappers=[to_tensor_names, to_array_names])

    def test_mapper(self, mapper, dummy_batch):
        result = mapper(dummy_batch)

        expected = {"image_array": 1, "text_array": 2, "label": "dummy_label"}
        for key, value in expected.items():
            assert result[key] == value

        # Neither the original names nor the intermediate ones may remain.
        for stale_key in ("image", "text", "image_tensor", "text_tensor"):
            assert stale_key not in result
class TestTorchvisionMapper:
    """TorchvisionMapper applying ``CenterCrop(224)`` followed by ``ToPILImage``."""

    @pytest.fixture()
    def dummy_batch(self):
        # Use torch.rand so values lie in [0, 1), the range ToPILImage expects for
        # float tensors. torch.randn yields negative and >1 values that do not
        # survive its float -> uint8 scaling cleanly.
        return {
            "image": torch.rand(3, 256, 256),
            "text": 2,
            "label": "dummy_label",
        }

    @pytest.fixture()
    def mapper(self):
        """Mapper that transforms ``image`` in place (no ``output_key``)."""
        return TorchvisionMapper(
            TorchvisionMapperConfig(
                key="image",
                transforms=["CenterCrop", "ToPILImage"],
                transforms_kwargs=[{"size": 224}, {}],
            )
        )

    def test_mapper(self, mapper, dummy_batch):
        output_data = mapper(dummy_batch)
        # Check the type first so a failure points at the conversion, not at `.size`.
        assert isinstance(output_data["image"], Image.Image)
        assert output_data["image"].size == (224, 224)
        assert output_data["text"] == 2
        assert output_data["label"] == "dummy_label"

    @pytest.fixture()
    def mapper_with_output_key(self):
        """Mapper that writes the transformed image to ``image_transformed``."""
        return TorchvisionMapper(
            TorchvisionMapperConfig(
                key="image",
                output_key="image_transformed",
                transforms=["CenterCrop", "ToPILImage"],
                transforms_kwargs=[{"size": 224}, {}],
            )
        )

    # Previously also named `test_mapper`, which silently shadowed the test above
    # so it never ran; a distinct name lets pytest collect both.
    def test_mapper_with_output_key(self, mapper_with_output_key, dummy_batch):
        output_data = mapper_with_output_key(dummy_batch)
        assert isinstance(output_data["image_transformed"], Image.Image)
        assert output_data["image_transformed"].size == (224, 224)
        # The source key must be left as the original, uncropped tensor.
        assert isinstance(output_data["image"], torch.Tensor)
        assert output_data["image"].size() == (3, 256, 256)
        assert output_data["text"] == 2
        assert output_data["label"] == "dummy_label"
class TestRescaleMapper:
    """RescaleMapper writing its result back to the same ``image`` key.

    NOTE(review): the input comes from ``torch.rand`` and is already in [0, 1),
    which satisfies the [-1, 1] bounds below. These checks would therefore also
    pass for an identity mapping. Consider asserting the exact rescaled values
    once the mapper's formula is confirmed.
    """

    @pytest.fixture()
    def dummy_batch(self):
        batch = {"text": 2, "label": "dummy_label"}
        batch["image"] = torch.rand(3, 256, 256)
        return batch

    @pytest.fixture()
    def mapper(self):
        rescale_config = RescaleMapperConfig(input_key="image", output_key="image")
        return RescaleMapper(rescale_config)

    def test_mapper(self, mapper, dummy_batch):
        output_data = mapper(dummy_batch)
        image = output_data["image"]
        assert (image <= 1).all()
        assert (image >= -1).all()
        assert output_data["text"] == 2
        assert output_data["label"] == "dummy_label"