Spaces:
Runtime error
Runtime error
"""Interface for implementing a source.""" | |
import abc | |
from typing import ClassVar, Iterable, Optional | |
import numpy as np | |
import pandas as pd | |
import pyarrow as pa | |
from pydantic import BaseModel, validator | |
from ..schema import ( | |
Field, | |
ImageInfo, | |
Item, | |
Schema, | |
arrow_dtype_to_dtype, | |
arrow_schema_to_schema, | |
field, | |
) | |
class SourceSchema(BaseModel):
    """The schema of a source, as returned by `Source.source_schema`."""
    # Mapping from each top-level column name to the field that describes that column.
    fields: dict[str, Field]
    # Number of items in the source (used for progress reporting); None when not provided.
    num_items: Optional[int] = None
class SourceProcessResult(BaseModel):
    """The result after processing all the shards of a source dataset."""
    # Paths of the files produced by processing the source.
    filepaths: list[str]
    # Schema describing the processed data.
    data_schema: Schema
    # Number of items produced by processing.
    num_items: int
    # Image metadata for the source, if any; None when not provided.
    images: Optional[list[ImageInfo]] = None
class Source(abc.ABC, BaseModel):
    """Interface for sources to implement. A source processes a set of shards and writes files.

    Subclasses must define the class-level ``name`` and implement ``source_schema`` and
    ``process``. ``setup`` and ``teardown`` are optional hooks that default to no-ops.
    """

    # ClassVars do not get serialized with pydantic.
    name: ClassVar[str]
    # The source_name gets populated automatically at validation time from the static `name`
    # class attribute, so it is serialized and the source author doesn't have to define both the
    # static property and the field.
    source_name: Optional[str] = None

    class Config:
        # Pydantic (v1) setting: treat underscore-prefixed attributes as private, not model fields.
        underscore_attrs_are_private = True

    # `always=True` makes pydantic run this validator even when `source_name` is not supplied,
    # which is exactly the case where it must be filled in from the static `name`.
    @validator('source_name', always=True)
    def validate_source_name(cls, source_name: Optional[str]) -> str:
        """Return the static name when the source_name hasn't yet been set.

        Args:
          source_name: The incoming value of the `source_name` field (None when unset).

        Returns:
          `source_name` when it is already set (e.g. deserialized from JSON), else `cls.name`.

        Raises:
          ValueError: If the concrete source class does not itself define `name`.
        """
        # When it's already been set from JSON, just return it.
        if source_name:
            return source_name

        # Only a `name` defined directly on this class counts; the base class merely annotates it.
        if 'name' not in cls.__dict__:
            raise ValueError('Source attribute "name" must be defined.')

        return cls.name

    @abc.abstractmethod
    def source_schema(self) -> SourceSchema:
        """Return the source schema for this source.

        Returns:
          A SourceSchema with
            fields: a mapping from top-level columns to the fields that describe the source schema.
            num_items: the number of items in the source, used for progress.
        """

    def setup(self) -> None:
        """Prepare the source for processing.

        This lets the source do setup outside the constructor, but before it is processed. This
        avoids potentially expensive computation when the pydantic model is deserialized. The
        default implementation does nothing.
        """

    def teardown(self) -> None:
        """Tear down the source after processing. The default implementation does nothing."""

    @abc.abstractmethod
    def process(self) -> Iterable[Item]:
        """Process the source upload request.

        Returns:
          An iterable of the `Item`s produced by this source.
        """
def schema_from_df(df: pd.DataFrame, index_colname: str) -> SourceSchema:
    """Build a `SourceSchema` from the columns and index of a dataframe.

    Args:
      df: The dataframe whose columns become top-level fields.
      index_colname: Name of the extra field that describes the dataframe index.

    Returns:
      A `SourceSchema` holding every dataframe column (converted through an Arrow schema) followed
      by an `index_colname` field typed after the index, with `num_items` set to the row count.
    """
    # Pandas stores string indices with the generic `object` dtype; treat those as strings so the
    # Arrow conversion maps them to a string type.
    raw_index_type = df.index.dtype
    np_index_type = np.dtype(str) if raw_index_type == np.dtype(object) else raw_index_type
    index_type = arrow_dtype_to_dtype(pa.from_numpy_dtype(np_index_type))

    arrow_schema = pa.Schema.from_pandas(df, preserve_index=False)
    column_schema = arrow_schema_to_schema(arrow_schema)

    # Column fields first, then the index field (which replaces a same-named column in place).
    merged_fields = dict(column_schema.fields)
    merged_fields[index_colname] = field(dtype=index_type)
    return SourceSchema(fields=merged_fields, num_items=len(df))
def normalize_column_name(name: str) -> str:
    """Return the column name unchanged.

    Column-name normalization is currently disabled, so this is the identity function.
    """
    # NOTE: a replacement of ' ', ':' and '.' with '_' used to live here and is disabled.
    normalized = name
    return normalized