"""Item: an individual entry in the dataset.""" import csv import io from collections import deque from datetime import datetime from enum import Enum from typing import Any, Optional, Union, cast import numpy as np import pyarrow as pa from pydantic import BaseModel, StrictInt, StrictStr, validator from typing_extensions import TypedDict MANIFEST_FILENAME = 'manifest.json' PARQUET_FILENAME_PREFIX = 'data' # We choose `__rowid__` inspired by the standard `rowid` pseudocolumn in DBs: # https://docs.oracle.com/cd/B19306_01/server.102/b14200/pseudocolumns008.htm ROWID = '__rowid__' PATH_WILDCARD = '*' VALUE_KEY = '__value__' SIGNAL_METADATA_KEY = '__metadata__' TEXT_SPAN_START_FEATURE = 'start' TEXT_SPAN_END_FEATURE = 'end' EMBEDDING_KEY = 'embedding' # Python doesn't work with recursive types. These types provide some notion of type-safety. Scalar = Union[bool, datetime, int, float, str, bytes] Item = Any # Contains a string field name, a wildcard for repeateds, or a specific integer index for repeateds. # This path represents a path to a particular column. # Examples: # ['article', 'field'] represents {'article': {'field': VALUES}} # ['article', '*', 'field'] represents {'article': [{'field': VALUES}, {'field': VALUES}]} # ['article', '0', 'field'] represents {'article': {'field': VALUES}} PathTuple = tuple[StrictStr, ...] Path = Union[PathTuple, StrictStr] PathKeyedItem = tuple[Path, Item] # These fields are for for python only and not written to a schema. RichData = Union[str, bytes] VectorKey = tuple[Union[StrictStr, StrictInt], ...] PathKey = VectorKey class DataType(str, Enum): """Enum holding the dtype for a field.""" STRING = 'string' # Contains {start, end} offset integers with a reference_column. STRING_SPAN = 'string_span' BOOLEAN = 'boolean' # Ints. INT8 = 'int8' INT16 = 'int16' INT32 = 'int32' INT64 = 'int64' UINT8 = 'uint8' UINT16 = 'uint16' UINT32 = 'uint32' UINT64 = 'uint64' # Floats. 
FLOAT16 = 'float16' FLOAT32 = 'float32' FLOAT64 = 'float64' ### Time ### # Time of day (no time zone). TIME = 'time' # Calendar date (year, month, day), no time zone. DATE = 'date' # An "Instant" stored as number of microseconds (µs) since 1970-01-01 00:00:00+00 (UTC time zone). TIMESTAMP = 'timestamp' # Time span, stored as microseconds. INTERVAL = 'interval' BINARY = 'binary' EMBEDDING = 'embedding' NULL = 'null' def __repr__(self) -> str: return self.value class SignalInputType(str, Enum): """Enum holding the signal input type.""" TEXT = 'text' TEXT_EMBEDDING = 'text_embedding' IMAGE = 'image' def __repr__(self) -> str: return self.value SIGNAL_TYPE_TO_VALID_DTYPES: dict[SignalInputType, list[DataType]] = { SignalInputType.TEXT: [DataType.STRING, DataType.STRING_SPAN], SignalInputType.IMAGE: [DataType.BINARY], } def signal_type_supports_dtype(input_type: SignalInputType, dtype: DataType) -> bool: """Returns True if the signal compute type supports the dtype.""" return dtype in SIGNAL_TYPE_TO_VALID_DTYPES[input_type] Bin = tuple[str, Optional[Union[float, int]], Optional[Union[float, int]]] class Field(BaseModel): """Holds information for a field in the schema.""" repeated_field: Optional['Field'] = None fields: Optional[dict[str, 'Field']] = None dtype: Optional[DataType] = None # Defined as the serialized signal when this field is the root result of a signal. signal: Optional[dict[str, Any]] = None # Maps a named bin to a tuple of (start, end) values. 
bins: Optional[list[Bin]] = None categorical: Optional[bool] = None @validator('fields') def either_fields_or_repeated_field_is_defined( cls, fields: Optional[dict[str, 'Field']], values: dict[str, Any]) -> Optional[dict[str, 'Field']]: """Error if both `fields` and `repeated_fields` are defined.""" if not fields: return fields if values.get('repeated_field'): raise ValueError('Both "fields" and "repeated_field" should not be defined') if VALUE_KEY in fields: raise ValueError(f'{VALUE_KEY} is a reserved field name.') return fields @validator('dtype', always=True) def infer_default_dtype(cls, dtype: Optional[DataType], values: dict[str, Any]) -> Optional[DataType]: """Infers the default value for dtype if not explicitly provided.""" if dtype and values.get('repeated_field'): raise ValueError('dtype and repeated_field cannot both be defined.') if not values.get('repeated_field') and not values.get('fields') and not dtype: raise ValueError('One of "fields", "repeated_field", or "dtype" should be defined') return dtype @validator('bins') def validate_bins(cls, bins: list[Bin]) -> list[Bin]: """Validate the bins.""" if len(bins) < 2: raise ValueError('Please specify at least two bins.') _, first_start, _ = bins[0] if first_start is not None: raise ValueError('The first bin should have a `None` start value.') _, _, last_end = bins[-1] if last_end is not None: raise ValueError('The last bin should have a `None` end value.') for i, (_, start, _) in enumerate(bins): if i == 0: continue prev_bin = bins[i - 1] _, _, prev_end = prev_bin if start != prev_end: raise ValueError( f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.') return bins @validator('categorical') def validate_categorical(cls, categorical: bool, values: dict[str, Any]) -> bool: """Validate the categorical field.""" if categorical and is_float(values['dtype']): raise ValueError('Categorical fields cannot be float dtypes.') return categorical def __str__(self) -> str: return 
_str_field(self, indent=0) def __repr__(self) -> str: return f' {self.__class__.__name__}::{self.json(exclude_none=True, indent=2)}' class Schema(BaseModel): """Database schema.""" fields: dict[str, Field] # Cached leafs. _leafs: Optional[dict[PathTuple, Field]] = None # Cached flat list of all the fields. _all_fields: Optional[list[tuple[PathTuple, Field]]] = None class Config: arbitrary_types_allowed = True underscore_attrs_are_private = True @property def leafs(self) -> dict[PathTuple, Field]: """Return all the leaf fields in the schema. A leaf is defined as a node that contains a value. NOTE: Leafs may contain children. Leafs can be found as any node that has a dtype defined. """ if self._leafs: return self._leafs result: dict[PathTuple, Field] = {} q: deque[tuple[PathTuple, Field]] = deque([((), Field(fields=self.fields))]) while q: path, field = q.popleft() if field.dtype: # Nodes with dtypes act as leafs. They also may have children. result[path] = field if field.fields: for name, child_field in field.fields.items(): child_path = (*path, name) q.append((child_path, child_field)) elif field.repeated_field: child_path = (*path, PATH_WILDCARD) q.append((child_path, field.repeated_field)) self._leafs = result return result @property def all_fields(self) -> list[tuple[PathTuple, Field]]: """Return all the fields in the schema as a flat list.""" if self._all_fields: return self._all_fields result: list[tuple[PathTuple, Field]] = [] q: deque[tuple[PathTuple, Field]] = deque([((), Field(fields=self.fields))]) while q: path, field = q.popleft() if path: result.append((path, field)) if field.fields: for name, child_field in field.fields.items(): child_path = (*path, name) q.append((child_path, child_field)) elif field.repeated_field: child_path = (*path, PATH_WILDCARD) q.append((child_path, field.repeated_field)) self._all_fields = result return result def has_field(self, path: PathTuple) -> bool: """Returns if the field is found at the given path.""" field = 
cast(Field, self) for path_part in path: if field.fields: field = cast(Field, field.fields.get(path_part)) if not field: return False elif field.repeated_field: if path_part != PATH_WILDCARD: return False field = field.repeated_field else: return False return True def get_field(self, path: PathTuple) -> Field: """Returns the field at the given path.""" if path == (ROWID,): return Field(dtype=DataType.STRING) field = cast(Field, self) for name in path: if field.fields: if name not in field.fields: raise ValueError(f'Path {path} not found in schema') field = field.fields[name] elif field.repeated_field: if name != PATH_WILDCARD: raise ValueError(f'Invalid path {path}') field = field.repeated_field else: raise ValueError(f'Invalid path {path}') return field def __str__(self) -> str: return _str_fields(self.fields, indent=0) def __repr__(self) -> str: return self.json(exclude_none=True, indent=2) def schema(schema_like: object) -> Schema: """Parse a schema-like object to a Schema object.""" field = _parse_field_like(schema_like) if not field.fields: raise ValueError('Schema must have fields') return Schema(fields=field.fields) def field( dtype: Optional[Union[DataType, str]] = None, signal: Optional[dict] = None, fields: Optional[object] = None, bins: Optional[list[Bin]] = None, categorical: Optional[bool] = None, ) -> Field: """Parse a field-like object to a Field object.""" field = _parse_field_like(fields or {}, dtype) if signal: field.signal = signal if dtype: if isinstance(dtype, str): dtype = DataType(dtype) field.dtype = dtype if bins: field.bins = bins if categorical is not None: field.categorical = categorical return field class SpanVector(TypedDict): """A span with a vector.""" span: tuple[int, int] vector: np.ndarray def lilac_span(start: int, end: int, metadata: dict[str, Any] = {}) -> Item: """Creates a lilac span item, representing a pointer to a slice of text.""" return {VALUE_KEY: {TEXT_SPAN_START_FEATURE: start, TEXT_SPAN_END_FEATURE: end}, **metadata} 
def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> Item:
  """Creates a lilac embedding item, representing a vector with a pointer to a slice of text."""
  return lilac_span(start, end, {EMBEDDING_KEY: embedding})


def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
  """Recursively converts a field-like object to a `Field`.

  Accepted shapes:
    - a `Field`, returned as is;
    - a dict of name -> field-like, parsed into named children (`dtype` applies to this node);
    - a string, interpreted as a `DataType` value;
    - a non-empty list, whose first element describes the repeated field.

  Raises:
    ValueError: If the object has any other shape, or is an empty list.
  """
  if isinstance(field_like, Field):
    return field_like
  if isinstance(field_like, dict):
    children = {name: _parse_field_like(child) for name, child in field_like.items()}
    if isinstance(dtype, str):
      dtype = DataType(dtype)
    return Field(fields=children or None, dtype=dtype)
  if isinstance(field_like, str):
    return Field(dtype=DataType(field_like))
  if isinstance(field_like, list):
    if not field_like:
      # Without this guard `field_like[0]` raises a bare IndexError.
      raise ValueError(f'Cannot parse field like: {field_like}')
    return Field(repeated_field=_parse_field_like(field_like[0], dtype=dtype))
  raise ValueError(f'Cannot parse field like: {field_like}')


def child_item_from_column_path(item: Item, path: Path) -> Item:
  """Return the last (child) item from a column path.

  A string path is first split with `normalize_path`. Path parts made only of digits index into
  lists; every other part is used as a key.

  Raises:
    ValueError: If the path contains the repeated wildcard.
  """
  child_item_value = item
  # Normalizing keeps a string path from being iterated one character at a time.
  for path_part in normalize_path(path):
    if path_part == PATH_WILDCARD:
      raise ValueError(
        'child_item_from_column_path cannot be called with a path that contains a repeated '
        f'wildcard: "{path}"')
    # path_part can either be an integer or a string for a dictionary, both of which we can
    # directly index with. Integer parts have no `isdigit`, so only strings are converted.
    if isinstance(path_part, str) and path_part.isdigit():
      child_path: Union[int, str] = int(path_part)
    else:
      child_path = path_part
    child_item_value = child_item_value[child_path]
  return child_item_value


def column_paths_match(path_match: Path, specific_path: Path) -> bool:
  """Test whether two column paths match.

  Args:
    path_match: A column path that contains wildcards, and sub-paths. This path will be used for
      testing the second specific path.
    specific_path: A column path that specifically identifies a field.

  Returns:
    Whether specific_path matches the path_match. This will only match when the
    paths are equal length. If a user wants to enrich everything with an array, they must use the
    path wildcard '*' in their path match.
  """
  if isinstance(path_match, str):
    path_match = (path_match,)
  if isinstance(specific_path, str):
    specific_path = (specific_path,)

  if len(path_match) != len(specific_path):
    return False

  # A wildcard in `path_match` accepts any part; every other part must be equal.
  return all(
    match_part == PATH_WILDCARD or match_part == specific_part
    for match_part, specific_part in zip(path_match, specific_path))


def normalize_path(path: Path) -> PathTuple:
  """Normalizes a dot separated path, but ignores dots inside quotes, like regular SQL.

  Examples
  - 'a.b.c' will be parsed as ('a', 'b', 'c').
  - '"a.b".c' will be parsed as ('a.b', 'c').
  - '"a".b.c' will be parsed as ('a', 'b', 'c').

  Tuple paths are returned unchanged. An empty string yields an empty tuple.
  """
  if isinstance(path, str):
    # The csv reader yields no row for an empty string; default to no parts instead of letting
    # `next` raise StopIteration.
    return tuple(next(csv.reader(io.StringIO(path), delimiter='.'), []))
  return path


class ImageInfo(BaseModel):
  """Info about an individual image."""
  path: Path


class SourceManifest(BaseModel):
  """The manifest that describes the dataset run, including schema and parquet files."""
  # List of parquet filepaths storing the data. The paths can be relative to `manifest.json`.
  files: list[str]
  # The data schema.
  data_schema: Schema

  # Image information for the dataset.
  images: Optional[list[ImageInfo]] = None


def _str_fields(fields: dict[str, Field], indent: int) -> str:
  """Renders named fields one per line, each prefixed by `indent` spaces."""
  prefix = ' ' * indent
  return '\n'.join(
    f'{prefix}{name}:{_str_field(child, indent=indent + 2)}' for name, child in fields.items())


def _str_field(field: Field, indent: int) -> str:
  """Renders a single field: nested fields, `list(...)` for repeated fields, or the dtype."""
  if field.fields:
    prefix = '\n' if indent > 0 else ''
    return f'{prefix}{_str_fields(field.fields, indent)}'
  if field.repeated_field:
    return f' list({_str_field(field.repeated_field, indent)})'
  # Use the raw value: `format()` of a str-mixin Enum member yields the value on older Pythons but
  # `DataType.X` on newer ones. `getattr` leaves a missing (None) dtype rendered as before.
  return f' {getattr(field.dtype, "value", field.dtype)}'


def dtype_to_arrow_schema(dtype: DataType) -> Union[pa.Schema, pa.DataType]:
  """Convert the dtype to an arrow dtype.

  Raises:
    ValueError: If the dtype has no arrow equivalent.
  """
  if dtype == DataType.STRING:
    return pa.string()
  elif dtype == DataType.BOOLEAN:
    return pa.bool_()
  elif dtype == DataType.FLOAT16:
    return pa.float16()
  elif dtype == DataType.FLOAT32:
    return pa.float32()
  elif dtype == DataType.FLOAT64:
    return pa.float64()
  elif dtype == DataType.INT8:
    return pa.int8()
  elif dtype == DataType.INT16:
    return pa.int16()
  elif dtype == DataType.INT32:
    return pa.int32()
  elif dtype == DataType.INT64:
    return pa.int64()
  elif dtype == DataType.UINT8:
    return pa.uint8()
  elif dtype == DataType.UINT16:
    return pa.uint16()
  elif dtype == DataType.UINT32:
    return pa.uint32()
  elif dtype == DataType.UINT64:
    return pa.uint64()
  elif dtype == DataType.BINARY:
    return pa.binary()
  elif dtype == DataType.TIME:
    # `pa.time64` requires a unit ('us' or 'ns'); microseconds matches the timestamp and interval
    # resolution used below.
    return pa.time64('us')
  elif dtype == DataType.DATE:
    return pa.date64()
  elif dtype == DataType.TIMESTAMP:
    return pa.timestamp('us')
  elif dtype == DataType.INTERVAL:
    return pa.duration('us')
  elif dtype == DataType.EMBEDDING:
    # We reserve an empty column for embeddings in parquet files so they can be queried.
    # The values are *not* filled out. If parquet and duckdb support embeddings in the future, we
    # can set this dtype to the relevant pyarrow type.
    return pa.null()
  elif dtype == DataType.STRING_SPAN:
    return pa.struct({
      VALUE_KEY: pa.struct({
        TEXT_SPAN_START_FEATURE: pa.int32(),
        TEXT_SPAN_END_FEATURE: pa.int32()
      })
    })
  elif dtype == DataType.NULL:
    return pa.null()
  else:
    raise ValueError(f'Can not convert dtype "{dtype}" to arrow dtype')


def schema_to_arrow_schema(schema: Union[Schema, Field]) -> pa.Schema:
  """Convert our schema to arrow schema."""
  arrow_schema = cast(pa.Schema, _schema_to_arrow_schema_impl(schema))
  # Rebuild from (name, type) pairs so the result is a `pa.Schema` even when the implementation
  # returned a struct type.
  arrow_fields = {arrow_field.name: arrow_field.type for arrow_field in arrow_schema}
  return pa.schema(arrow_fields)


def _schema_to_arrow_schema_impl(schema: Union[Schema, Field]) -> Union[pa.Schema, pa.DataType]:
  """Convert a schema to an apache arrow schema."""
  if isinstance(schema, Schema) and not schema.fields:
    # An empty top-level schema has no `repeated_field` or `dtype` to fall back on.
    return pa.schema([])

  if schema.fields:
    arrow_fields: dict[str, Union[pa.Schema, pa.DataType]] = {}
    for name, child in schema.fields.items():
      if name == ROWID:
        # The row id is converted from its dtype alone.
        arrow_fields[name] = dtype_to_arrow_schema(cast(DataType, child.dtype))
      else:
        arrow_fields[name] = _schema_to_arrow_schema_impl(child)

    if isinstance(schema, Schema):
      # Top-level schemas do not have __value__ fields.
      return pa.schema(arrow_fields)

    # When nodes have both dtype and children, we add __value__ alongside the fields.
    if schema.dtype:
      value_schema = dtype_to_arrow_schema(schema.dtype)
      if schema.dtype == DataType.STRING_SPAN:
        # The span type is already wrapped in a __value__ struct; unwrap it to avoid nesting twice.
        value_schema = value_schema[VALUE_KEY].type
      arrow_fields[VALUE_KEY] = value_schema

    return pa.struct(arrow_fields)

  leaf = cast(Field, schema)
  if leaf.repeated_field:
    return pa.list_(_schema_to_arrow_schema_impl(leaf.repeated_field))

  return dtype_to_arrow_schema(cast(DataType, leaf.dtype))


def arrow_dtype_to_dtype(arrow_dtype: pa.DataType) -> DataType:
  """Convert arrow dtype to our dtype.

  Raises:
    ValueError: If the arrow dtype has no equivalent `DataType`.
  """
  # Ints.
  if arrow_dtype == pa.int8():
    return DataType.INT8
  elif arrow_dtype == pa.int16():
    return DataType.INT16
  elif arrow_dtype == pa.int32():
    return DataType.INT32
  elif arrow_dtype == pa.int64():
    return DataType.INT64
  elif arrow_dtype == pa.uint8():
    return DataType.UINT8
  elif arrow_dtype == pa.uint16():
    return DataType.UINT16
  elif arrow_dtype == pa.uint32():
    return DataType.UINT32
  elif arrow_dtype == pa.uint64():
    return DataType.UINT64
  # Floats.
  elif arrow_dtype == pa.float16():
    return DataType.FLOAT16
  elif arrow_dtype == pa.float32():
    return DataType.FLOAT32
  elif arrow_dtype == pa.float64():
    return DataType.FLOAT64
  # Time.
  elif pa.types.is_time(arrow_dtype):
    return DataType.TIME
  elif pa.types.is_date(arrow_dtype):
    return DataType.DATE
  elif pa.types.is_timestamp(arrow_dtype):
    return DataType.TIMESTAMP
  elif pa.types.is_duration(arrow_dtype):
    return DataType.INTERVAL
  # Others.
  elif arrow_dtype == pa.string():
    return DataType.STRING
  elif pa.types.is_binary(arrow_dtype) or pa.types.is_fixed_size_binary(arrow_dtype):
    return DataType.BINARY
  elif pa.types.is_boolean(arrow_dtype):
    return DataType.BOOLEAN
  elif arrow_dtype == pa.null():
    return DataType.NULL
  else:
    raise ValueError(f'Can not convert arrow dtype "{arrow_dtype}" to our dtype')


def arrow_schema_to_schema(schema: pa.Schema) -> Schema:
  """Convert arrow schema to our schema."""
  # TODO(nsthorat): Change this implementation to allow more complicated reading of arrow schemas
  # into our schema by inferring values when {__value__: value} is present in the pyarrow schema.
  # This isn't necessary today as this util is only needed by sources which do not have data in the
  # lilac format.
  return cast(Schema, _arrow_schema_to_schema_impl(schema))


def _arrow_schema_to_schema_impl(schema: Union[pa.Schema, pa.DataType]) -> Union[Schema, Field]:
  """Convert an apache arrow schema to our schema."""
  if isinstance(schema, (pa.Schema, pa.StructType)):
    fields: dict[str, Field] = {
      arrow_field.name: cast(Field, _arrow_schema_to_schema_impl(arrow_field.type))
      for arrow_field in schema
    }
    # Only the top-level arrow schema becomes a `Schema`; nested structs become `Field`s.
    return Schema(fields=fields) if isinstance(schema, pa.Schema) else Field(fields=fields)
  elif isinstance(schema, pa.ListType):
    return Field(repeated_field=cast(Field, _arrow_schema_to_schema_impl(schema.value_field.type)))
  else:
    return Field(dtype=arrow_dtype_to_dtype(schema))


def is_float(dtype: DataType) -> bool:
  """Check if a dtype is a float dtype."""
  return dtype in [DataType.FLOAT16, DataType.FLOAT32, DataType.FLOAT64]


def is_integer(dtype: DataType) -> bool:
  """Check if a dtype is an integer dtype."""
  return dtype in [
    DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, DataType.UINT8,
    DataType.UINT16, DataType.UINT32, DataType.UINT64
  ]


def is_temporal(dtype: DataType) -> bool:
  """Check if a dtype is a temporal dtype."""
  return dtype in [DataType.TIME, DataType.DATE, DataType.TIMESTAMP, DataType.INTERVAL]


def is_ordinal(dtype: DataType) -> bool:
  """Check if a dtype is an ordinal dtype."""
  return is_float(dtype) or is_integer(dtype) or is_temporal(dtype)