"""CSV source."""
from typing import Iterable

import duckdb
import pandas as pd
from pydantic import Field as PydanticField
from typing_extensions import override

from ...schema import Item
from ...utils import download_http_files
from ..duckdb_utils import duckdb_setup
from .source import Source, SourceSchema, schema_from_df

ROW_ID_COLUMN = '__row_id__'


class JSONDataset(Source):
  """JSON data loader

  Supports both JSON and JSONL.

  JSON files can live locally as a filepath, or point to an external URL.
  """ # noqa: D415, D400
  name = 'json'

  filepaths: list[str] = PydanticField(description='A list of filepaths to JSON files.')

  _source_schema: SourceSchema
  _df: pd.DataFrame

  @override
  def setup(self) -> None:
    # Download JSON files to local cache if they are via HTTP to speed up duckdb.
    filepaths = download_http_files(self.filepaths)

    con = duckdb.connect(database=':memory:')

    # DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
    s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

    # NOTE: We use duckdb here to increase parallelism for multiple files.
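    # The Python repr of the list of paths is also valid DuckDB list syntax, so it can
    # be interpolated directly into the read_json_auto() call below.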
    self._df = con.execute(f"""
      {duckdb_setup(con)}
      SELECT * FROM read_json_auto(
        {s3_filepaths},
        IGNORE_ERRORS=true
      )
    """).df()

    # Create the source schema in setup() to share it between process() and source_schema().
    self._source_schema = schema_from_df(self._df, ROW_ID_COLUMN)

  @override
  def source_schema(self) -> SourceSchema:
    """Return the source schema."""
    return self._source_schema

  @override
  def process(self) -> Iterable[Item]:
    """Process the source upload request."""
    cols = self._df.columns.tolist()
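    # itertuples() yields the DataFrame index first, then the column values, so each
    # row becomes a dict of {ROW_ID_COLUMN: index, column: value, ...}.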
    yield from ({
      ROW_ID_COLUMN: idx,
      **dict(zip(cols, item_vals)),
    } for idx, *item_vals in self._df.itertuples())
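

# A minimal usage sketch, assuming this module is importable and the JSONL path below
# exists (both are illustrative, not part of this file's API surface):
#
#   source = JSONDataset(filepaths=['./data/items.jsonl'])
#   source.setup()
#   schema = source.source_schema()
#   for item in source.process():
#     ...  # each item maps ROW_ID_COLUMN and the JSON columns to their values.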