File size: 4,356 Bytes
e4f9cbe
 
 
 
56cce61
e4f9cbe
 
 
 
 
 
 
 
2226ee3
e4f9cbe
 
 
 
 
2226ee3
 
e4f9cbe
2226ee3
e4f9cbe
 
 
 
 
 
 
 
 
2226ee3
e4f9cbe
2226ee3
 
 
e4f9cbe
 
 
 
2226ee3
 
 
 
 
 
 
 
 
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2226ee3
 
 
 
 
e4f9cbe
 
2226ee3
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
2226ee3
e4f9cbe
 
 
 
 
 
2226ee3
 
 
e4f9cbe
 
 
 
 
 
 
 
 
2226ee3
 
 
 
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""A data loader standalone binary. This should only be run as a script to load a dataset.

To run the source loader as a binary directly:

poetry run python -m lilac.data_loader \
  --dataset_name=movies_dataset \
  --output_dir=./data/ \
  --config_path=./datasets/the_movies_dataset.json
"""
import json
import os
import pathlib
import uuid
from typing import Iterable, Optional, Union

import click
import pandas as pd
from distributed import Client

from .config import data_path
from .data.dataset import Dataset
from .data.dataset_utils import write_items_to_parquet
from .db_manager import get_dataset
from .schema import (
  MANIFEST_FILENAME,
  PARQUET_FILENAME_PREFIX,
  UUID_COLUMN,
  Field,
  Item,
  Schema,
  SourceManifest,
  field,
  is_float,
)
from .sources.default_sources import register_default_sources
from .sources.source import Source
from .sources.source_registry import resolve_source
from .tasks import TaskStepId, progress
from .utils import get_dataset_output_dir, log, open_file


def create_dataset(
  namespace: str,
  dataset_name: str,
  source_config: Source,
) -> Dataset:
  """Create a dataset under `namespace` by materializing the given source config."""
  # Write the source out under the configured data path, then return the
  # freshly registered dataset handle.
  base_dir = data_path()
  process_source(base_dir, namespace, dataset_name, source_config)
  return get_dataset(namespace, dataset_name)


def process_source(base_dir: Union[str, pathlib.Path],
                   namespace: str,
                   dataset_name: str,
                   source: Source,
                   task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
  """Process a source end-to-end: read, normalize, and persist parquet + manifest.

  Returns a tuple of (output directory, number of items written).
  """
  out_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)

  source.setup()
  src_schema = source.source_schema()

  # Attach row UUIDs and scrub NaN values in non-float columns.
  item_stream = normalize_items(source.process(), src_schema.fields)

  # Wrap the stream so task progress is reported as items are consumed.
  item_stream = progress(
    item_stream,
    task_step_id=task_step_id,
    estimated_len=src_schema.num_items,
    step_description=f'Reading from source {source.name}...')

  # Drop `None` placeholders only after the progress wrapper has counted them.
  item_stream = (it for it in item_stream if it is not None)

  full_schema = Schema(fields={**src_schema.fields, UUID_COLUMN: field('string')})
  filepath, num_items = write_items_to_parquet(
    items=item_stream,
    output_dir=out_dir,
    schema=full_schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  # Record a single-shard manifest next to the parquet file.
  manifest = SourceManifest(
    files=[os.path.basename(filepath)], data_schema=full_schema, images=None)
  with open_file(os.path.join(out_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))
  log(f'Dataset "{dataset_name}" written to {out_dir}')

  return out_dir, num_items


def normalize_items(items: Iterable[Item],
                    fields: dict[str, Field]) -> Iterable[Optional[Item]]:
  """Sanitize items by assigning row UUIDs and replacing NaN/NaT with None.

  `None` items are yielded through untouched so downstream progress tracking
  can still count them before they are filtered out.

  Args:
    items: The raw items emitted by a source.
    fields: The source schema fields, used to decide which columns may hold
      pandas NaN/NaT values that must be nulled out.

  Yields:
    The sanitized items (or `None` for missing items).
  """
  # Floats keep NaN (a valid float value); every other declared dtype gets
  # NaN/NaT replaced with None. The loop variable is `f` (not `field`) to
  # avoid shadowing the `field` helper imported from the schema module.
  replace_nan_fields = [
    field_name for field_name, f in fields.items() if f.dtype and not is_float(f.dtype)
  ]
  for item in items:
    if item is None:
      yield item
      continue

    # Add a row uuid if the source did not provide one.
    if UUID_COLUMN not in item:
      item[UUID_COLUMN] = uuid.uuid4().hex

    # Replace scalar NaN/NaT values with None. The truthiness check skips
    # None/empty values cheaply before calling pd.isna.
    for field_name in replace_nan_fields:
      item_value = item.get(field_name)
      if item_value and pd.isna(item_value):
        item[field_name] = None

    yield item


@click.command()
@click.option(
  '--output_dir',
  required=True,
  type=str,
  help='The output directory to write the parquet files to.')
@click.option(
  '--config_path',
  required=True,
  type=str,
  help='The path to a json file describing the source configuration.')
@click.option(
  '--dataset_name', required=True, type=str, help='The dataset name, used for binary mode only.')
@click.option(
  '--namespace',
  required=False,
  default='local',
  type=str,
  help='The namespace to use. Defaults to "local".')
def main(output_dir: str, config_path: str, dataset_name: str, namespace: str) -> None:
  """Run the source loader as a binary."""
  # Sources must be registered before resolve_source can look one up by name.
  register_default_sources()

  with open_file(config_path) as f:
    # Parse the json file in a dict.
    source_dict = json.load(f)
    source = resolve_source(source_dict)

  # Run the load on a dask distributed client; .result() blocks until the
  # submitted process_source call finishes (or re-raises its exception).
  client = Client()
  client.submit(process_source, output_dir, namespace, dataset_name, source).result()


if __name__ == '__main__':
  main()