"""Tests for item.py.""" import pyarrow as pa import pytest from .schema import ( PATH_WILDCARD, TEXT_SPAN_END_FEATURE, TEXT_SPAN_START_FEATURE, VALUE_KEY, DataType, Field, Item, arrow_schema_to_schema, child_item_from_column_path, column_paths_match, field, schema, schema_to_arrow_schema, ) NESTED_TEST_SCHEMA = schema({ 'person': { 'name': 'string', 'last_name': 'string_span', # Contains a double nested array of primitives. 'data': [['float32']], # Contains a value and children. 'description': field( 'string', fields={ 'toxicity': 'float32', 'sentences': [field('string_span', fields={'len': 'int32'})] }) }, 'addresses': [{ 'city': 'string', 'zipcode': 'int16', 'current': 'boolean', 'locations': [{ 'latitude': 'float16', 'longitude': 'float64' }] }], 'blob': 'binary' }) NESTED_TEST_ITEM: Item = { 'person': { 'name': 'Test Name', 'last_name': (5, 9) }, 'addresses': [{ 'city': 'a', 'zipcode': 1, 'current': False, 'locations': [{ 'latitude': 1.5, 'longitude': 3.8 }, { 'latitude': 2.9, 'longitude': 15.3 }], }, { 'city': 'b', 'zipcode': 2, 'current': True, 'locations': [{ 'latitude': 11.2, 'longitude': 20.1 }, { 'latitude': 30.1, 'longitude': 40.2 }], }] } def test_field_ctor_validation() -> None: with pytest.raises( ValueError, match='One of "fields", "repeated_field", or "dtype" should be defined'): Field() with pytest.raises(ValueError, match='Both "fields" and "repeated_field" should not be defined'): Field( fields={'name': Field(dtype=DataType.STRING)}, repeated_field=Field(dtype=DataType.INT32), ) with pytest.raises(ValueError, match=f'{VALUE_KEY} is a reserved field name'): Field(fields={VALUE_KEY: Field(dtype=DataType.STRING)},) def test_schema_leafs() -> None: expected = { ('addresses', PATH_WILDCARD, 'city'): Field(dtype=DataType.STRING), ('addresses', PATH_WILDCARD, 'current'): Field(dtype=DataType.BOOLEAN), ('addresses', PATH_WILDCARD, 'locations', PATH_WILDCARD, 'latitude'): Field(dtype=DataType.FLOAT16), ('addresses', PATH_WILDCARD, 'locations', PATH_WILDCARD, 'longitude'): Field(dtype=DataType.FLOAT64), ('addresses', PATH_WILDCARD, 'zipcode'): Field(dtype=DataType.INT16), ('blob',): Field(dtype=DataType.BINARY), ('person', 'name'): Field(dtype=DataType.STRING), ('person', 'last_name'): Field(dtype=DataType.STRING_SPAN), ('person', 'data', PATH_WILDCARD, PATH_WILDCARD): Field(dtype=DataType.FLOAT32), ('person', 'description'): Field( dtype=DataType.STRING, fields={ 'toxicity': Field(dtype=DataType.FLOAT32), 'sentences': Field( repeated_field=Field( dtype=DataType.STRING_SPAN, fields={'len': Field(dtype=DataType.INT32)})) }), ('person', 'description', 'toxicity'): Field(dtype=DataType.FLOAT32), ('person', 'description', 'sentences', PATH_WILDCARD): Field( fields={'len': Field(dtype=DataType.INT32)}, dtype=DataType.STRING_SPAN), ('person', 'description', 'sentences', PATH_WILDCARD, 'len'): Field(dtype=DataType.INT32), } assert NESTED_TEST_SCHEMA.leafs == expected def test_schema_to_arrow_schema() -> None: arrow_schema = schema_to_arrow_schema(NESTED_TEST_SCHEMA) assert arrow_schema == pa.schema({ 'person': pa.struct({ 'name': pa.string(), # The dtype for STRING_SPAN is implemented as a struct with a {start, end}. 'last_name': pa.struct({ VALUE_KEY: pa.struct({ TEXT_SPAN_START_FEATURE: pa.int32(), TEXT_SPAN_END_FEATURE: pa.int32(), }) }), 'data': pa.list_(pa.list_(pa.float32())), 'description': pa.struct({ 'toxicity': pa.float32(), 'sentences': pa.list_( pa.struct({ 'len': pa.int32(), VALUE_KEY: pa.struct({ TEXT_SPAN_START_FEATURE: pa.int32(), TEXT_SPAN_END_FEATURE: pa.int32(), }) })), VALUE_KEY: pa.string(), }) }), 'addresses': pa.list_( pa.struct({ 'city': pa.string(), 'zipcode': pa.int16(), 'current': pa.bool_(), 'locations': pa.list_(pa.struct({ 'latitude': pa.float16(), 'longitude': pa.float64() })), })), 'blob': pa.binary(), }) def test_arrow_schema_to_schema() -> None: arrow_schema = pa.schema({ 'person': pa.struct({ 'name': pa.string(), 'data': pa.list_(pa.list_(pa.float32())) }), 'addresses': pa.list_( pa.struct({ 'city': pa.string(), 'zipcode': pa.int16(), 'current': pa.bool_(), 'locations': pa.list_(pa.struct({ 'latitude': pa.float16(), 'longitude': pa.float64() })), })), 'blob': pa.binary(), }) expected_schema = schema({ 'person': { 'name': 'string', 'data': [['float32']] }, 'addresses': [{ 'city': 'string', 'zipcode': 'int16', 'current': 'boolean', 'locations': [{ 'latitude': 'float16', 'longitude': 'float64', }] }], 'blob': 'binary', }) assert arrow_schema_to_schema(arrow_schema) == expected_schema def test_simple_schema_str() -> None: assert str(schema({'person': 'string'})) == 'person: string' def test_child_item_from_column_path() -> None: assert child_item_from_column_path(NESTED_TEST_ITEM, ('addresses', '0', 'locations', '0', 'longitude')) == 3.8 assert child_item_from_column_path(NESTED_TEST_ITEM, ('addresses', '1', 'city')) == 'b' def test_child_item_from_column_path_raises_wildcard() -> None: with pytest.raises( ValueError, match='cannot be called with a path that contains a repeated wildcard'): child_item_from_column_path(NESTED_TEST_ITEM, ('addresses', PATH_WILDCARD, 'city')) def test_column_paths_match() -> None: assert column_paths_match(path_match=('person', 'name'), specific_path=('person', 'name')) is True assert column_paths_match( path_match=('person', 'name'), specific_path=('person', 'not_name')) is False # Wildcards work for structs. assert column_paths_match( path_match=(PATH_WILDCARD, 'name'), specific_path=('person', 'name')) is True assert column_paths_match( path_match=(PATH_WILDCARD, 'name'), specific_path=('person', 'not_name')) is False # Wildcards work for repeateds. assert column_paths_match( path_match=('person', PATH_WILDCARD, 'name'), specific_path=('person', '0', 'name')) is True assert column_paths_match( path_match=('person', PATH_WILDCARD, 'name'), specific_path=('person', '0', 'not_name')) is False # Sub-path matches always return False. assert column_paths_match(path_match=(PATH_WILDCARD,), specific_path=('person', 'name')) is False assert column_paths_match( path_match=( 'person', PATH_WILDCARD, ), specific_path=('person', '0', 'name')) is False def test_nested_schema_str() -> None: assert str(NESTED_TEST_SCHEMA) == """\ person: name: string last_name: string_span data: list( list( float32)) description: toxicity: float32 sentences: list( len: int32) addresses: list( city: string zipcode: int16 current: boolean locations: list( latitude: float16 longitude: float64)) blob: binary\ """