Spaces:
Runtime error
Runtime error
File size: 4,518 Bytes
e4f9cbe 55dc3dd e4f9cbe 55dc3dd e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
"""A data loader standalone binary. This should only be run as a script to load a dataset.
To run the source loader as a binary directly:
poetry run python -m src.data_loader \
--dataset_name=movies_dataset \
--output_dir=./data/ \
--config_path=./datasets/the_movies_dataset.json
"""
import json
import math
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union, cast
import click
import pandas as pd
from distributed import Client
from .data.dataset_utils import write_items_to_parquet
from .data.sources.default_sources import register_default_sources
from .data.sources.source import Source
from .data.sources.source_registry import resolve_source
from .schema import (
MANIFEST_FILENAME,
PARQUET_FILENAME_PREFIX,
UUID_COLUMN,
DataType,
Field,
Item,
Schema,
SourceManifest,
field,
)
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file
def process_source(base_dir: Union[str, pathlib.Path],
                   namespace: str,
                   dataset_name: str,
                   source: Source,
                   task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
  """Load `source` into a parquet dataset under the given namespace and name.

  Args:
    base_dir: The root directory under which dataset output directories live.
    namespace: The dataset namespace.
    dataset_name: The dataset name.
    source: The source to read items from.
    task_step_id: When provided, item consumption is reported as task progress.

  Returns:
    A tuple of (dataset output directory, number of items written).
  """
  out_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)

  source.setup()
  src_schema = source.source_schema()

  # Attach row UUIDs and sanitize NaN / NaT values as items stream through.
  item_stream = normalize_items(source.process(), src_schema.fields)

  if task_step_id is not None:
    # Report progress while the (possibly long) source stream is consumed.
    item_stream = progress(
      item_stream,
      task_step_id=task_step_id,
      estimated_len=src_schema.num_items,
      step_description=f'Reading from source {source.name}...')

  # Drop `None` items only after progress accounting has counted them.
  item_stream = cast(Iterable[Item], (x for x in item_stream if x is not None))

  full_schema = Schema(fields={**src_schema.fields, UUID_COLUMN: field('string')})
  parquet_path, item_count = write_items_to_parquet(
    items=item_stream,
    output_dir=out_dir,
    schema=full_schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  manifest = SourceManifest(
    files=[os.path.basename(parquet_path)], data_schema=full_schema, images=None)
  with open_file(os.path.join(out_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))
  log(f'Manifest for dataset "{dataset_name}" written to {out_dir}')

  return out_dir, item_count
def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Iterable[Item]:
  """Sanitize items by adding row UUIDs and removing NaNs and NaTs.

  Args:
    items: The raw items produced by a source. `None` entries are yielded
      unchanged so downstream progress accounting can still count them.
    fields: The source schema fields, keyed by field name.

  Yields:
    The sanitized items, each with UUID_COLUMN populated.
  """
  # String / null columns may carry float NaN (e.g. from pandas) or other
  # non-string values that need coercion.
  replace_nan_fields = {
    field_name for field_name, field_info in fields.items()
    if field_info.dtype == DataType.STRING or field_info.dtype == DataType.NULL
  }
  timestamp_fields = {
    field_name for field_name, field_info in fields.items()
    if field_info.dtype == DataType.TIMESTAMP
  }
  for item in items:
    if item is None:
      yield item
      continue
    # Add a row uuid if it doesn't exist.
    if UUID_COLUMN not in item:
      item[UUID_COLUMN] = uuid.uuid4().hex
    # Fix NaN string fields.
    for name in replace_nan_fields:
      item_value = item.get(name)
      if item_value and not isinstance(item_value, str):
        # Only floats can be NaN; calling math.isnan on other truthy
        # non-numeric values (e.g. a list) would raise TypeError.
        if isinstance(item_value, float) and math.isnan(item_value):
          item[name] = None
        else:
          item[name] = str(item_value)
    # Fix NaT (not a time) timestamp fields.
    for name in timestamp_fields:
      item_value = item.get(name)
      if item_value and pd.isnull(item_value):
        item[name] = None
    yield item
@click.command()
@click.option(
  '--output_dir',
  required=True,
  type=str,
  help='The output directory to write the parquet files to.')
@click.option(
  '--config_path',
  required=True,
  type=str,
  help='The path to a json file describing the source configuration.')
@click.option(
  '--dataset_name', required=True, type=str, help='The dataset name, used for binary mode only.')
@click.option(
  '--namespace',
  required=False,
  default='local',
  type=str,
  help='The namespace to use. Defaults to "local".')
def main(output_dir: str, config_path: str, dataset_name: str, namespace: str) -> None:
  """Run the source loader as a binary."""
  register_default_sources()
  with open_file(config_path) as f:
    # Parse the json file in a dict.
    source_dict = json.load(f)
  source = resolve_source(source_dict)
  # Use the dask client as a context manager so the cluster is torn down
  # even when the load raises; the original leaked the client on any path.
  with Client() as client:
    client.submit(process_source, output_dir, namespace, dataset_name, source).result()


if __name__ == '__main__':
  main()
|