nikhil_staging / src /data_loader.py
nsthorat's picture
Push
55dc3dd
raw
history blame
No virus
4.52 kB
"""A data loader standalone binary. This should only be run as a script to load a dataset.
To run the source loader as a binary directly:
poetry run python -m src.data_loader \
--dataset_name=movies_dataset \
--output_dir=./data/ \
--config_path=./datasets/the_movies_dataset.json
"""
import json
import math
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union, cast
import click
import pandas as pd
from distributed import Client
from .data.dataset_utils import write_items_to_parquet
from .data.sources.default_sources import register_default_sources
from .data.sources.source import Source
from .data.sources.source_registry import resolve_source
from .schema import (
MANIFEST_FILENAME,
PARQUET_FILENAME_PREFIX,
UUID_COLUMN,
DataType,
Field,
Item,
Schema,
SourceManifest,
field,
)
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file
def process_source(base_dir: Union[str, pathlib.Path],
namespace: str,
dataset_name: str,
source: Source,
task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
"""Process a source."""
output_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)
source.setup()
source_schema = source.source_schema()
items = source.process()
# Add UUIDs and fix NaN in string columns.
items = normalize_items(items, source_schema.fields)
# Add progress.
if task_step_id is not None:
items = progress(
items,
task_step_id=task_step_id,
estimated_len=source_schema.num_items,
step_description=f'Reading from source {source.name}...')
# Filter out the `None`s after progress.
items = cast(Iterable[Item], (item for item in items if item is not None))
data_schema = Schema(fields={**source_schema.fields, UUID_COLUMN: field('string')})
filepath, num_items = write_items_to_parquet(
items=items,
output_dir=output_dir,
schema=data_schema,
filename_prefix=PARQUET_FILENAME_PREFIX,
shard_index=0,
num_shards=1)
filenames = [os.path.basename(filepath)]
manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
f.write(manifest.json(indent=2, exclude_none=True))
log(f'Manifest for dataset "{dataset_name}" written to {output_dir}')
return output_dir, num_items
def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Item:
"""Sanitize items by removing NaNs and NaTs."""
replace_nan_fields = set([
field_name for field_name, field in fields.items()
if field.dtype == DataType.STRING or field.dtype == DataType.NULL
])
timestamp_fields = set(
[field_name for field_name, field in fields.items() if field.dtype == DataType.TIMESTAMP])
for item in items:
if item is None:
yield item
continue
# Add row uuid if it doesn't exist.
if UUID_COLUMN not in item:
item[UUID_COLUMN] = uuid.uuid4().hex
# Fix NaN string fields.
for name in replace_nan_fields:
item_value = item.get(name)
if item_value and not isinstance(item_value, str):
if math.isnan(item_value):
item[name] = None
else:
item[name] = str(item_value)
# Fix NaT (not a time) timestamp fields.
for name in timestamp_fields:
item_value = item.get(name)
if item_value and pd.isnull(item_value):
item[name] = None
yield item
@click.command()
@click.option(
'--output_dir',
required=True,
type=str,
help='The output directory to write the parquet files to.')
@click.option(
'--config_path',
required=True,
type=str,
help='The path to a json file describing the source configuration.')
@click.option(
'--dataset_name', required=True, type=str, help='The dataset name, used for binary mode only.')
@click.option(
'--namespace',
required=False,
default='local',
type=str,
help='The namespace to use. Defaults to "local".')
def main(output_dir: str, config_path: str, dataset_name: str, namespace: str) -> None:
"""Run the source loader as a binary."""
register_default_sources()
with open_file(config_path) as f:
# Parse the json file in a dict.
source_dict = json.load(f)
source = resolve_source(source_dict)
client = Client()
client.submit(process_source, output_dir, namespace, dataset_name, source).result()
if __name__ == '__main__':
main()