"""Interface for implementing a source.""" | |
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Type, Union | |
import numpy as np | |
import pandas as pd | |
import pyarrow as pa | |
from pydantic import BaseModel | |
if TYPE_CHECKING: | |
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny | |
from ..schema import ( | |
Field, | |
ImageInfo, | |
Item, | |
Schema, | |
arrow_dtype_to_dtype, | |
arrow_schema_to_schema, | |
field, | |
) | |


class SourceSchema(BaseModel):
  """The schema of a source."""
  fields: dict[str, Field]
  num_items: Optional[int] = None


class SourceProcessResult(BaseModel):
  """The result after processing all the shards of a source dataset."""
  filepaths: list[str]
  data_schema: Schema
  num_items: int
  images: Optional[list[ImageInfo]] = None


class Source(BaseModel):
  """Interface for sources to implement. A source processes a set of shards and writes files."""

  # ClassVars do not get serialized with pydantic.
  name: ClassVar[str]

  def dict(
    self,
    *,
    include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
  ) -> dict[str, Any]:
    """Override the default dict method to add `source_name`."""
    res = super().dict(
      include=include,
      exclude=exclude,
      by_alias=by_alias,
      skip_defaults=skip_defaults,
      exclude_unset=exclude_unset,
      exclude_defaults=exclude_defaults,
      exclude_none=exclude_none)
    res['source_name'] = self.name
    return res
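
  # Illustrative note (a sketch, not part of this interface): for a subclass
  # that declares `name: ClassVar[str] = 'csv'`, the override above makes
  # `source.dict()` return {'source_name': 'csv', ...}, so the concrete source
  # class can be identified when the serialized model is loaded back.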

  class Config:
    underscore_attrs_are_private = True

    @staticmethod
    def schema_extra(schema: dict[str, Any], source: Type['Source']) -> None:
      """Add `source_name` to the JSON schema so it is serialized with the model.

      Pydantic defaults the schema title to the class name.
      """
      signal_prop: dict[str, Any]
      if hasattr(source, 'name'):
        signal_prop = {'enum': [source.name]}
      else:
        signal_prop = {'type': 'string'}
      schema['properties'] = {'source_name': signal_prop, **schema['properties']}
      if 'required' not in schema:
        schema['required'] = []
      schema['required'].append('source_name')

  def source_schema(self) -> SourceSchema:
    """Return the source schema for this source.

    Returns:
      A SourceSchema with:
        fields: a mapping from top-level column names to fields that describe the schema of the
          source.
        num_items: the number of items in the source, used for progress reporting.
    """
    raise NotImplementedError

  def setup(self) -> None:
    """Prepare the source for processing.

    This allows the source to do setup outside the constructor, but before it is processed. It
    avoids potentially expensive computation when the pydantic model is deserialized.
    """
    pass

  def teardown(self) -> None:
    """Tear down the source after processing."""
    pass

  def process(self) -> Iterable[Item]:
    """Process the source upload request, yielding the source's items."""
    raise NotImplementedError
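

# A minimal sketch of a concrete source, included here only to illustrate the
# interface above. The class name `ExampleInMemorySource` and its `rows` field
# are hypothetical; a real source would read shards from disk or an external
# service inside `process`.
class ExampleInMemorySource(Source):
  """A toy source that yields a fixed list of pre-built items."""
  name: ClassVar[str] = 'example_in_memory'
  rows: list[Item] = []

  def source_schema(self) -> SourceSchema:
    # Declare a single string column. This assumes `field` accepts a string
    # dtype name, as it is used elsewhere in this package; a real source would
    # infer its fields from the underlying data.
    return SourceSchema(
      fields={'text': field(dtype='string')}, num_items=len(self.rows))

  def process(self) -> Iterable[Item]:
    # Setup and teardown are no-ops here since the items are already in memory.
    yield from self.rows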


def schema_from_df(df: pd.DataFrame, index_colname: str) -> SourceSchema:
  """Create a source schema from a dataframe."""
  index_np_dtype = df.index.dtype
  # String index dtypes are stored as objects.
  if index_np_dtype == np.dtype(object):
    index_np_dtype = np.dtype(str)
  index_dtype = arrow_dtype_to_dtype(pa.from_numpy_dtype(index_np_dtype))

  schema = arrow_schema_to_schema(pa.Schema.from_pandas(df, preserve_index=False))
  return SourceSchema(
    fields={
      **schema.fields, index_colname: field(dtype=index_dtype)
    }, num_items=len(df))
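
# Example usage (a sketch; the column and index names below are made up):
#   df = pd.DataFrame({'text': ['hello', 'world']},
#                     index=pd.Index(['id1', 'id2'], name='id'))
#   source_schema = schema_from_df(df, index_colname='id')
#   # source_schema.num_items == 2, with fields for both 'text' and 'id'.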


def normalize_column_name(name: str) -> str:
  """Normalize a column name."""
  return name
  # return name.replace(' ', '_').replace(':', '_').replace('.', '_')