"""Tests utils of for dataset_test."""
import os
import pathlib
from typing import Optional, Type, cast

from typing_extensions import Protocol

from ..schema import (
  MANIFEST_FILENAME,
  PARQUET_FILENAME_PREFIX,
  VALUE_KEY,
  DataType,
  Field,
  Item,
  Schema,
  SourceManifest,
  field,
)
from ..signals.signal import EMBEDDING_KEY
from ..utils import get_dataset_output_dir, open_file
from .dataset import Dataset
from .dataset_utils import is_primitive, lilac_span, write_items_to_parquet

TEST_NAMESPACE = 'test_namespace'
TEST_DATASET_NAME = 'test_dataset'


def _infer_dtype(value: Item) -> DataType:
  """Infer the DataType of a primitive value."""
  if isinstance(value, str):
    return DataType.STRING
  # bool must be checked before int, since bool is a subclass of int.
  elif isinstance(value, bool):
    return DataType.BOOLEAN
  elif isinstance(value, bytes):
    return DataType.BINARY
  elif isinstance(value, float):
    return DataType.FLOAT32
  elif isinstance(value, int):
    return DataType.INT32
  else:
    raise ValueError(f'Cannot infer dtype of primitive value: {value}')
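
# Example (a sketch of the mapping above): _infer_dtype('hi') -> DataType.STRING,
# _infer_dtype(True) -> DataType.BOOLEAN, _infer_dtype(3) -> DataType.INT32,
# _infer_dtype(3.0) -> DataType.FLOAT32, _infer_dtype(b'x') -> DataType.BINARY.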


def _infer_field(item: Item) -> Field:
  """Infer the field for a single item."""
  if isinstance(item, dict):
    fields: dict[str, Field] = {}
    for k, v in item.items():
      fields[k] = _infer_field(cast(Item, v))
    # An enriched value stores its dtype under VALUE_KEY; lift it onto the parent.
    dtype = None
    if VALUE_KEY in fields:
      dtype = fields[VALUE_KEY].dtype
      del fields[VALUE_KEY]
    return Field(fields=fields, dtype=dtype)
  elif is_primitive(item):
    return Field(dtype=_infer_dtype(item))
  elif isinstance(item, list):
    # Assumes a non-empty list; the first element determines the repeated field.
    return Field(repeated_field=_infer_field(item[0]))
  else:
    raise ValueError(f'Cannot infer schema of item: {item}')
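
# Example (sketch): for an enriched item {VALUE_KEY: 'hello', 'lang': 'en'},
# _infer_field lifts the VALUE_KEY dtype onto the parent and returns
# Field(dtype=DataType.STRING, fields={'lang': Field(dtype=DataType.STRING)}).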


def _infer_schema(items: list[Item]) -> Schema:
  """Infer the schema from the items."""
  schema = Schema(fields={})
  for item in items:
    # Named item_field to avoid shadowing the imported `field` helper.
    item_field = _infer_field(item)
    if not item_field.fields:
      raise ValueError(f'Invalid schema of item. Expected an object, but got: {item}')
    schema.fields = {**schema.fields, **item_field.fields}
  return schema
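
# Example (sketch): _infer_schema([{'text': 'hi', 'score': 1.0}]) yields a Schema
# mapping 'text' -> STRING and 'score' -> FLOAT32. Fields inferred from later
# items merge over (and may overwrite) those from earlier items.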


class TestDataMaker(Protocol):
  """A function that creates a test dataset."""

  def __call__(self, items: list[Item], schema: Optional[Schema] = None) -> Dataset:
    """Create a test dataset."""
    ...
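
# A minimal sketch of how a test suite might provide a TestDataMaker, assuming a
# pytest `tmp_path` fixture and some concrete Dataset subclass (`DatasetImpl` is
# an illustrative name, not an API defined in this module):
#
#   @pytest.fixture
#   def make_test_data(tmp_path: pathlib.Path) -> TestDataMaker:
#     def _make(items: list[Item], schema: Optional[Schema] = None) -> Dataset:
#       return make_dataset(DatasetImpl, tmp_path, items, schema)
#     return _make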


def make_dataset(dataset_cls: Type[Dataset],
                 tmp_path: pathlib.Path,
                 items: list[Item],
                 schema: Optional[Schema] = None) -> Dataset:
  """Create a test dataset."""
  schema = schema or _infer_schema(items)
  _write_items(tmp_path, TEST_DATASET_NAME, items, schema)
  return dataset_cls(TEST_NAMESPACE, TEST_DATASET_NAME)
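
# Example (sketch): make_dataset(DatasetImpl, tmp_path, [{'text': 'hello'}]) writes
# the items under tmp_path with an inferred schema ('text' -> STRING) and returns a
# DatasetImpl for TEST_NAMESPACE/TEST_DATASET_NAME. This assumes the library's data
# path has already been pointed at tmp_path (e.g. by a test fixture), since the
# Dataset constructor takes only a namespace and dataset name.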


def _write_items(tmpdir: pathlib.Path, dataset_name: str, items: list[Item],
                 schema: Schema) -> None:
  """Write the items JSON to the dataset format: manifest.json and parquet files."""
  source_dir = get_dataset_output_dir(str(tmpdir), TEST_NAMESPACE, dataset_name)
  os.makedirs(source_dir)
  simple_parquet_files, _ = write_items_to_parquet(
    items, source_dir, schema, filename_prefix=PARQUET_FILENAME_PREFIX, shard_index=0,
    num_shards=1)
  manifest = SourceManifest(files=[simple_parquet_files], data_schema=schema)
  with open_file(os.path.join(source_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))
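
# The resulting on-disk layout, roughly (exact paths and shard names come from
# get_dataset_output_dir, PARQUET_FILENAME_PREFIX and write_items_to_parquet):
#
#   <output dir for TEST_NAMESPACE/<dataset_name>>/
#     <PARQUET_FILENAME_PREFIX>-... .parquet   (single shard: shard 0 of 1)
#     manifest.json                            (SourceManifest with the schema)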


def enriched_item(value: Optional[Item] = None, metadata: dict[str, Item] = {}) -> Item:
  """Wrap a value in a dict with the value key."""
  return {VALUE_KEY: value, **metadata}
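
# Example (sketch): enriched_item('hello', {'lang': 'en'}) returns
# {VALUE_KEY: 'hello', 'lang': 'en'}, the enriched shape that _infer_field
# unwraps above.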


def enriched_embedding_span(start: int, end: int, metadata: dict[str, Item] = {}) -> Item:
  """Makes an item that represents an embedding span that was enriched with metadata."""
  return lilac_span(start, end, {EMBEDDING_KEY: {VALUE_KEY: None, **metadata}})
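
# Example (sketch): enriched_embedding_span(0, 5, {'score': 0.9}) builds a lilac
# span over [0, 5) whose EMBEDDING_KEY payload is {VALUE_KEY: None, 'score': 0.9},
# matching the field shape produced by enriched_embedding_span_field below.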


def enriched_embedding_span_field(metadata: Optional[object] = {}) -> Field:
  """Makes a field that represents an embedding span that was enriched with metadata."""
  return field('string_span', fields={EMBEDDING_KEY: field('embedding', fields=metadata)})
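
# Example (sketch): enriched_embedding_span_field({'score': 'float32'}) describes
# the schema for items built with enriched_embedding_span: a string_span field
# whose EMBEDDING_KEY child is an embedding field carrying a float32 'score'.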