"""A data loader standalone binary. This should only be run as a script to load a dataset. | |
To run the source loader as a binary directly: | |
poetry run python -m lilac.data_loader \ | |
--dataset_name=movies_dataset \ | |
--output_dir=./data/ \ | |
--config_path=./datasets/the_movies_dataset.json | |
""" | |
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union

import pandas as pd

from .config import CONFIG_FILENAME, DatasetConfig
from .data.dataset import Dataset, default_settings
from .data.dataset_utils import write_items_to_parquet
from .db_manager import get_dataset
from .env import data_path
from .schema import (
  MANIFEST_FILENAME,
  PARQUET_FILENAME_PREFIX,
  ROWID,
  Field,
  Item,
  Schema,
  SourceManifest,
  is_float,
)
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file, to_yaml


def create_dataset(config: DatasetConfig) -> Dataset:
  """Load a dataset from a given source configuration."""
  process_source(data_path(), config)
  return get_dataset(config.namespace, config.name)
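

# A minimal usage sketch for create_dataset. The `JSONSource` name and its
# arguments are assumptions for illustration; substitute whichever Source
# subclass and parameters your setup provides:
#
#   config = DatasetConfig(
#     namespace='local',
#     name='the_movies_dataset',
#     source=JSONSource(filepaths=['./datasets/the_movies_dataset.json']))
#   dataset = create_dataset(config)

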
def process_source(base_dir: Union[str, pathlib.Path],
                   config: DatasetConfig,
                   task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
  """Process a source, writing its items, manifest and config to the dataset output directory."""
  output_dir = get_dataset_output_dir(base_dir, config.namespace, config.name)

  config.source.setup()
  source_schema = config.source.source_schema()
  items = config.source.process()

  # Add rowids and fix NaN in string columns.
  items = normalize_items(items, source_schema.fields)

  # Add progress.
  items = progress(
    items,
    task_step_id=task_step_id,
    estimated_len=source_schema.num_items,
    step_description=f'Reading from source {config.source.name}...')

  # Filter out the `None`s after progress.
  items = (item for item in items if item is not None)

  data_schema = Schema(fields=source_schema.fields.copy())
  filepath, num_items = write_items_to_parquet(
    items=items,
    output_dir=output_dir,
    schema=data_schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  filenames = [os.path.basename(filepath)]
  manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
  with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))

  if not config.settings:
    dataset = get_dataset(config.namespace, config.name)
    config.settings = default_settings(dataset)

  with open_file(os.path.join(output_dir, CONFIG_FILENAME), 'w') as f:
    f.write(to_yaml(config.dict(exclude_defaults=True, exclude_none=True)))

  log(f'Dataset "{config.name}" written to {output_dir}')

  return output_dir, num_items
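

# A rough sketch of calling process_source directly (paths are illustrative).
# On return, `output_dir` holds a single parquet shard, the source manifest
# (MANIFEST_FILENAME) and the dataset config (CONFIG_FILENAME):
#
#   output_dir, num_items = process_source('./data', config)

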
def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Iterable[Item]:
"""Sanitize items by removing NaNs and NaTs.""" | |
replace_nan_fields = [ | |
field_name for field_name, field in fields.items() if field.dtype and not is_float(field.dtype) | |
] | |
for item in items: | |
if item is None: | |
yield item | |
continue | |
# Add rowid if it doesn't exist. | |
if ROWID not in item: | |
item[ROWID] = uuid.uuid4().hex | |
# Fix NaN values. | |
for field_name in replace_nan_fields: | |
item_value = item.get(field_name) | |
if item_value and pd.isna(item_value): | |
item[field_name] = None | |
yield item | |
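

# A minimal sketch of what normalize_items does to a single row (the field name
# and Field construction are illustrative and may differ in this repo):
#
#   fields = {'title': Field(dtype='string')}
#   [row] = list(normalize_items([{'title': float('nan')}], fields))
#   # row['title'] is now None, and row[ROWID] is a freshly generated hex id.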