"""A data loader standalone binary. This should only be run as a script to load a dataset. | |
To run the source loader as a binary directly: | |
poetry run python -m lilac.data_loader \ | |
--dataset_name=movies_dataset \ | |
--output_dir=./data/ \ | |
--config_path=./datasets/the_movies_dataset.json | |
""" | |

import json
import math
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union, cast

import click
import pandas as pd
from distributed import Client

from .data.dataset_utils import write_items_to_parquet
from .data.sources.default_sources import register_default_sources
from .data.sources.source import Source
from .data.sources.source_registry import resolve_source
from .schema import (
  MANIFEST_FILENAME,
  PARQUET_FILENAME_PREFIX,
  UUID_COLUMN,
  DataType,
  Field,
  Item,
  Schema,
  SourceManifest,
  field,
)
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file


def process_source(base_dir: Union[str, pathlib.Path],
                   namespace: str,
                   dataset_name: str,
                   source: Source,
                   task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
  """Process a source, writing its items to a parquet file and manifest in the dataset dir."""
  output_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)

  source.setup()
  source_schema = source.source_schema()
  items = source.process()

  # Add UUIDs and fix NaN in string columns.
  items = normalize_items(items, source_schema.fields)

  # Add progress.
  if task_step_id is not None:
    items = progress(
      items,
      task_step_id=task_step_id,
      estimated_len=source_schema.num_items,
      step_description=f'Reading from source {source.name}...')

  # Filter out the `None`s after progress.
  items = cast(Iterable[Item], (item for item in items if item is not None))

  data_schema = Schema(fields={**source_schema.fields, UUID_COLUMN: field('string')})
  filepath, num_items = write_items_to_parquet(
    items=items,
    output_dir=output_dir,
    schema=data_schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  filenames = [os.path.basename(filepath)]
  manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
  with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))
  log(f'Manifest for dataset "{dataset_name}" written to {output_dir}')

  return output_dir, num_items
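
# Illustrative programmatic use of `process_source` (the 'csv' source name and
# config keys below are assumptions for the sketch, not a documented API):
#
#   register_default_sources()
#   source = resolve_source({'source_name': 'csv', 'filepaths': ['./movies.csv']})
#   out_dir, num_items = process_source('./data', 'local', 'movies', source)
#   print(f'Wrote {num_items} items to {out_dir}')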


def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Iterable[Item]:
  """Add row UUIDs and sanitize NaN and NaT values by replacing them with None."""
  replace_nan_fields = set([
    field_name for field_name, field in fields.items()
    if field.dtype == DataType.STRING or field.dtype == DataType.NULL
  ])
  timestamp_fields = set(
    [field_name for field_name, field in fields.items() if field.dtype == DataType.TIMESTAMP])

  for item in items:
    if item is None:
      # Pass `None`s through so progress still counts them; they are filtered later.
      yield item
      continue

    # Add row uuid if it doesn't exist.
    if UUID_COLUMN not in item:
      item[UUID_COLUMN] = uuid.uuid4().hex

    # Fix NaN string fields. Pandas represents missing strings as float NaN, so
    # guard with isinstance before math.isnan to avoid a TypeError on
    # non-numeric values; any other non-string value is coerced to a string.
    for name in replace_nan_fields:
      item_value = item.get(name)
      if item_value and not isinstance(item_value, str):
        if isinstance(item_value, float) and math.isnan(item_value):
          item[name] = None
        else:
          item[name] = str(item_value)

    # Fix NaT (not a time) timestamp fields.
    for name in timestamp_fields:
      item_value = item.get(name)
      if item_value and pd.isnull(item_value):
        item[name] = None

    yield item
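
# Sketch of normalize_items behavior (the field name is illustrative): a float
# NaN in a string-typed column becomes None, and each row gains a hex uuid
# under UUID_COLUMN. `None` rows pass through untouched.
#
#   fields = {'title': field('string')}
#   out = list(normalize_items([{'title': float('nan')}, None], fields))
#   # -> [{'title': None, UUID_COLUMN: '<32-char hex>'}, None]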


@click.command()
@click.option('--output_dir', required=True, type=str, help='The output directory to write the dataset to.')
@click.option('--config_path', required=True, type=str, help='Path to a json file describing the source configuration.')
@click.option('--dataset_name', required=True, type=str, help='The name of the dataset.')
# NOTE: the 'local' default namespace is an assumption (the docstring's CLI example omits --namespace).
@click.option('--namespace', required=False, default='local', type=str, help='The dataset namespace.')
def main(output_dir: str, config_path: str, dataset_name: str, namespace: str) -> None:
  """Run the source loader as a binary."""
  register_default_sources()

  with open_file(config_path) as f:
    # Parse the json config file into a dict.
    source_dict = json.load(f)
  source = resolve_source(source_dict)

  # Submit the load to a dask distributed client and block until it completes.
  client = Client()
  client.submit(process_source, output_dir, namespace, dataset_name, source).result()


if __name__ == '__main__':
  main()