"""Tests for data_loader.py."""

import os
import pathlib
import uuid
from typing import Iterable

from pytest_mock import MockerFixture
from typing_extensions import override

from .data.dataset_duckdb import read_source_manifest
from .data.dataset_utils import parquet_filename
from .data.sources.source import Source, SourceSchema
from .data_loader import process_source
from .schema import PARQUET_FILENAME_PREFIX, UUID_COLUMN, Item, SourceManifest, schema
from .test_utils import fake_uuid, read_items
from .utils import DATASETS_DIR_NAME


class TestSource(Source):
  """A test source."""
  name = 'test_source'

  @override
  def setup(self) -> None:
    pass

  @override
  def source_schema(self) -> SourceSchema:
    """Return the source schema."""
    return SourceSchema(fields=schema({'x': 'int64', 'y': 'string'}).fields, num_items=2)

  @override
  def process(self) -> Iterable[Item]:
    return [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]


def test_data_loader(tmp_path: pathlib.Path, mocker: MockerFixture) -> None:
  # Make uuid.uuid4() deterministic so the generated UUID_COLUMN values can be asserted below.
  mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True)
  mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')]

  source = TestSource()
  # Spy on TestSource.setup to verify that process_source calls it exactly once.
  setup_mock = mocker.spy(TestSource, 'setup')

  output_dir, num_items = process_source(tmp_path, 'test_namespace', 'test_dataset', source)

  assert setup_mock.call_count == 1

  assert output_dir == os.path.join(tmp_path, DATASETS_DIR_NAME, 'test_namespace', 'test_dataset')
  assert num_items == 2

  source_manifest = read_source_manifest(output_dir)

  assert source_manifest == SourceManifest(
    files=[parquet_filename(PARQUET_FILENAME_PREFIX, 0, 1)],
    data_schema=schema({
      # UUID_COLUMN is generated by the data loader.
      UUID_COLUMN: 'string',
      'x': 'int64',
      'y': 'string'
    }),
  )

  items = read_items(output_dir, source_manifest.files, source_manifest.data_schema)

  assert items == [{
    UUID_COLUMN: fake_uuid(b'1').hex,
    'x': 1,
    'y': 'ten'
  }, {
    UUID_COLUMN: fake_uuid(b'2').hex,
    'x': 2,
    'y': 'twenty'
  }]