"""A data loader standalone binary. This should only be run as a script to load a dataset.

To run the source loader as a binary directly:

poetry run python -m src.data_loader \
  --dataset_name=movies_dataset \
  --output_dir=./data/ \
  --config_path=./datasets/the_movies_dataset.json
"""
import json
import math
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union, cast

import click
import pandas as pd
from distributed import Client

from .data.dataset_utils import write_items_to_parquet
from .data.sources.default_sources import register_default_sources
from .data.sources.source import Source
from .data.sources.source_registry import resolve_source
from .schema import (
  MANIFEST_FILENAME,
  PARQUET_FILENAME_PREFIX,
  UUID_COLUMN,
  DataType,
  Field,
  Item,
  Schema,
  SourceManifest,
  field,
)
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file


def process_source(base_dir: Union[str, pathlib.Path],
                   namespace: str,
                   dataset_name: str,
                   source: Source,
                   task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
  """Process a source."""
  output_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)

  source.setup()
  source_schema = source.source_schema()
  items = source.process()

  # Add UUIDs and fix NaN in string columns.
  items = normalize_items(items, source_schema.fields)

  # Add progress.
  if task_step_id is not None:
    items = progress(
      items,
      task_step_id=task_step_id,
      estimated_len=source_schema.num_items,
      step_description=f'Reading from source {source.name}...')

  # Filter out the `None`s after progress.
  items = cast(Iterable[Item], (item for item in items if item is not None))

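  # The schema written to disk is the source schema plus the UUID column that
  # normalize_items guarantees is present on every row.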
  data_schema = Schema(fields={**source_schema.fields, UUID_COLUMN: field('string')})
  filepath, num_items = write_items_to_parquet(
    items=items,
    output_dir=output_dir,
    schema=data_schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  filenames = [os.path.basename(filepath)]
  manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
  with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))
  log(f'Manifest for dataset "{dataset_name}" written to {output_dir}')

  return output_dir, num_items


def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Iterable[Item]:
  """Sanitize items by adding row UUIDs and removing NaNs and NaTs."""
  replace_nan_fields = set([
    field_name for field_name, field in fields.items()
    if field.dtype == DataType.STRING or field.dtype == DataType.NULL
  ])
  timestamp_fields = set(
    [field_name for field_name, field in fields.items() if field.dtype == DataType.TIMESTAMP])
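  # This is a generator: items are patched and yielded one at a time, so large datasets
  # stream through without being materialized in memory.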
  for item in items:
    if item is None:
      yield item
      continue

    # Add row uuid if it doesn't exist.
    if UUID_COLUMN not in item:
      item[UUID_COLUMN] = uuid.uuid4().hex

    # Fix NaN string fields.
    for name in replace_nan_fields:
      item_value = item.get(name)
      if item_value and not isinstance(item_value, str):
        if math.isnan(item_value):
          item[name] = None
        else:
          item[name] = str(item_value)

    # Fix NaT (not a time) timestamp fields.
    for name in timestamp_fields:
      item_value = item.get(name)
      if item_value and pd.isnull(item_value):
        item[name] = None

    yield item


@click.command()
@click.option(
  '--output_dir',
  required=True,
  type=str,
  help='The output directory to write the parquet files to.')
@click.option(
  '--config_path',
  required=True,
  type=str,
  help='The path to a json file describing the source configuration.')
@click.option(
  '--dataset_name', required=True, type=str, help='The dataset name, used for binary mode only.')
@click.option(
  '--namespace',
  required=False,
  default='local',
  type=str,
  help='The namespace to use. Defaults to "local".')
def main(output_dir: str, config_path: str, dataset_name: str, namespace: str) -> None:
  """Run the source loader as a binary."""
  register_default_sources()

  with open_file(config_path) as f:
    # Parse the JSON file into a dict.
    source_dict = json.load(f)
    source = resolve_source(source_dict)

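  # Client() with no arguments spins up a local Dask cluster; the load is submitted to one
  # of its workers and we block on the result.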
  client = Client()
  client.submit(process_source, output_dir, namespace, dataset_name, source).result()


if __name__ == '__main__':
  main()