File size: 2,172 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Tests for dataset utils."""
from ..schema import PathTuple
from .dataset_utils import count_primitives, flatten, unflatten, wrap_in_dicts


def test_flatten() -> None:
  """Checks that `flatten` yields the leaves of a ragged nested list in order."""
  nested = [[1, 2], [[3]], [4, 5, 5]]
  # Duplicate leaves (the two 5s) must both be present in the output.
  assert [*flatten(nested)] == [1, 2, 3, 4, 5, 5]


def test_flatten_primitive() -> None:
  """Checks that flattening a bare string yields that string as a single item."""
  expected = ['hello']
  flattened = list(flatten('hello'))
  assert flattened == expected


def test_unflatten() -> None:
  """Checks that `unflatten` restores the nesting of a previously flattened list.

  The flattened leaves are passed together with the original nested list as the
  second argument, and the result must equal the original structure.
  """
  nested = [[1, 2], [[3]], [4, 5, 5]]
  leaves = list(flatten(nested))
  expected = [[1, 2], [[3]], [4, 5, 5]]
  rebuilt = unflatten(leaves, nested)
  assert rebuilt == expected


def test_count_nested() -> None:
  """Checks that `count_primitives` counts every leaf of a ragged nested list."""
  nested = [[1, 2], [[3]], [4, 5, 6]]
  num_leaves = count_primitives(nested)
  assert num_leaves == 6


def test_wrap_in_dicts_with_spec_of_one_repeated() -> None:
  """Checks `wrap_in_dicts` for a spec with a single repeated level.

  Each inner list of the input becomes one output record, and every leaf is
  wrapped as `{'d': leaf}` inside the list stored under `a.b.c`.
  """
  rows = [[1, 2], [3], [4, 5, 5]]
  # Corresponds to a.b.c.*.d.
  spec: list[PathTuple] = [('a', 'b', 'c'), ('d',)]
  expected = [
    {'a': {'b': {'c': [{'d': 1}, {'d': 2}]}}},
    {'a': {'b': {'c': [{'d': 3}]}}},
    {'a': {'b': {'c': [{'d': 4}, {'d': 5}, {'d': 5}]}}},
  ]
  wrapped = wrap_in_dicts(rows, spec)
  assert wrapped == expected


def test_wrap_in_dicts_with_spec_of_double_repeated() -> None:
  """Checks `wrap_in_dicts` for a spec with two consecutive repeated levels.

  The empty tuple in the middle of the spec adds a list level with no dict keys
  in between, so the value under `a.b` is a list of lists of `{'c': leaf}`.
  """
  rows = [[[1, 2], [3, 4, 5]], [[6]], [[7], [8], [9, 10]]]
  # Corresponds to a.b.*.*.c.
  spec: list[PathTuple] = [('a', 'b'), (), ('c',)]
  expected = [
    {'a': {'b': [[{'c': 1}, {'c': 2}], [{'c': 3}, {'c': 4}, {'c': 5}]]}},
    {'a': {'b': [[{'c': 6}]]}},
    {'a': {'b': [[{'c': 7}], [{'c': 8}], [{'c': 9}, {'c': 10}]]}},
  ]
  wrapped = wrap_in_dicts(rows, spec)
  assert wrapped == expected


def test_unflatten_primitive() -> None:
  """Checks that `unflatten` returns the bare value when the original is a string."""
  template = 'hello'
  # A single-item flat list paired with a primitive original must not stay a list.
  assert unflatten(['hello'], template) == 'hello'


def test_unflatten_primitive_list() -> None:
  """Checks that `unflatten` leaves a flat list of strings unchanged."""
  template = ['hello', 'world']
  flat_values = ['hello', 'world']
  rebuilt = unflatten(flat_values, template)
  assert rebuilt == ['hello', 'world']