import math
import pickle
import tempfile
from functools import partial
from typing import Iterator, Optional, Union

import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitOperationAdd, HfFileSystem
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema

spark = None

def set_session(session):
    """Register the SparkSession used by `read_parquet` and `write_parquet`."""
    global spark
    spark = session
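
# Minimal usage sketch (the app name is illustrative; any existing SparkSession works):
#
#   from pyspark.sql import SparkSession
#   set_session(SparkSession.builder.appName("hf-parquet").getOrCreate())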


def _read(iterator: Iterator[pa.RecordBatch], columns: Optional[list[str]], filters: Optional[Union[list[tuple], list[list[tuple]]]], **kwargs) -> Iterator[pa.RecordBatch]:
    # Each incoming batch contains a column of Parquet file paths; open them with
    # PyArrow and stream the resulting record batches back to Spark.
    for batch in iterator:
        paths = batch[0].to_pylist()
        ds = pq.ParquetDataset(paths, **kwargs)
        yield from ds._dataset.to_batches(columns=columns, filter=pq.filters_to_expression(filters) if filters else None)


def read_parquet(
    path: str,
    columns: Optional[list[str]] = None,
    filters: Optional[Union[list[tuple], list[list[tuple]]]] = None,
    **kwargs,
) -> DataFrame:
    """
    Loads Parquet files from Hugging Face using PyArrow, returning a PySpark `DataFrame`.

    It reads Parquet files in a distributed manner.

    Access private or gated repositories using `huggingface-cli login` or by passing a token
    with the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    path : str
        Path to the file. Prefix with a protocol like `hf://` to read from Hugging Face.
        You can read from multiple files by passing a glob pattern.
    columns : list, default None
        If not None, only these columns will be read from the file.
    filters : List[Tuple] or List[List[Tuple]], default None
        To filter out data.
        Filter syntax: [[(column, op, val), ...],...]
        where op is [==, =, >, >=, <, <=, !=, in, not in]
        The innermost tuples are transposed into a set of filters applied
        through an `AND` operation.
        The outer list combines these sets of filters through an `OR`
        operation.
        A single list of tuples can also be used, meaning that no `OR`
        operation between set of filters is to be conducted.

    **kwargs
        Any additional kwargs are passed to pyarrow.parquet.ParquetDataset.

    Returns
    -------
    DataFrame
        DataFrame based on parquet file.

    Examples
    --------
    >>> path = "hf://datasets/username/dataset/data.parquet"
    >>> pd.DataFrame({"foo": range(5), "bar": range(5, 10)}).to_parquet(path)
    >>> read_parquet(path).show()
    +---+---+
    |foo|bar|
    +---+---+
    |  0|  5|
    |  1|  6|
    |  2|  7|
    |  3|  8|
    |  4|  9|
    +---+---+
    >>> read_parquet(path, columns=["bar"]).show()
    +---+
    |bar|
    +---+
    |  5|
    |  6|
    |  7|
    |  8|
    |  9|
    +---+
    >>> sel = [("foo", ">", 2)]
    >>> read_parquet(path, filters=sel).show()
    +---+---+
    |foo|bar|
    +---+---+
    |  3|  8|
    |  4|  9|
    +---+---+
    """
    filesystem: HfFileSystem = kwargs.pop("filesystem") if "filesystem" in kwargs else HfFileSystem(**kwargs.pop("storage_options", {}))
    paths = filesystem.glob(path)
    if not paths:
        raise FileNotFoundError(f"Couldn't find any file at {path}")
    # Distribute one file path per partition so the files are read in parallel.
    rdd = spark.sparkContext.parallelize([{"path": path} for path in paths], len(paths))
    df = spark.createDataFrame(rdd)
    # Infer the output schema from the first file, keeping only the requested columns.
    arrow_schema = pq.read_schema(filesystem.open(paths[0]))
    schema = pa.schema([field for field in arrow_schema if (columns is None or field.name in columns)], metadata=arrow_schema.metadata)
    return df.mapInArrow(
        partial(_read, columns=columns, filters=filters, filesystem=filesystem, schema=arrow_schema, **kwargs),
        from_arrow_schema(schema),
    )
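
# Hypothetical usage of the glob support mentioned in the docstring (the repository
# and file pattern are illustrative, not a real dataset):
#
#   df = read_parquet("hf://datasets/username/dataset/train-*.parquet", columns=["foo"])
#   df.count()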


def _preupload(iterator: Iterator[pa.RecordBatch], path: str, schema: pa.Schema, filesystem: HfFileSystem, row_group_size: Optional[int] = None, **kwargs) -> Iterator[pa.RecordBatch]:
    resolved_path = filesystem.resolve_path(path)
    with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
        # Write this partition's record batches to a local temporary Parquet file.
        with pq.ParquetWriter(temp_file.name, schema=schema, **kwargs) as writer:
            for batch in iterator:
                writer.write_batch(batch, row_group_size=row_group_size)
        # Preupload the file as an LFS blob; its final path in the repo is assigned later in `_commit`.
        addition = CommitOperationAdd(path_in_repo=temp_file.name, path_or_fileobj=temp_file.name)
        filesystem._api.preupload_lfs_files(repo_id=resolved_path.repo_id, additions=[addition], repo_type=resolved_path.repo_type, revision=resolved_path.revision)
    # Hand the pickled CommitOperationAdd back to Spark so `_commit` can finalize it.
    yield pa.record_batch({"addition": [pickle.dumps(addition)]}, schema=pa.schema({"addition": pa.binary()}))


def _commit(iterator: Iterator[pa.RecordBatch], path: str, filesystem: HfFileSystem, max_operations_per_commit: int = 50) -> Iterator[pa.RecordBatch]:
    resolved_path = filesystem.resolve_path(path)
    # Collect the pickled CommitOperationAdd objects produced by `_preupload`.
    additions: list[CommitOperationAdd] = [pickle.loads(addition) for addition in pa.Table.from_batches(iterator, schema=pa.schema({"addition": pa.binary()}))[0].to_pylist()]
    num_commits = math.ceil(len(additions) / max_operations_per_commit)
    # Give each preuploaded file its final shard name in the repository.
    for shard_idx, addition in enumerate(additions):
        addition.path_in_repo = resolved_path.path_in_repo.replace("{shard_idx:05d}", f"{shard_idx:05d}")
    # Commit the additions in chunks of at most `max_operations_per_commit` operations.
    for i in range(0, num_commits):
        operations = additions[i * max_operations_per_commit : (i + 1) * max_operations_per_commit]
        commit_message = "Upload using PySpark" + (f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "")
        filesystem._api.create_commit(repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision, operations=operations, commit_message=commit_message)
        yield pa.record_batch({"path": [addition.path_in_repo for addition in operations]}, schema=pa.schema({"path": pa.string()}))


def write_parquet(df: DataFrame, path: str, **kwargs) -> None:
    """
    Write Parquet files to Hugging Face using PyArrow.

    It uploads Parquet files in a distributed manner in two steps:

    1. Preupload the Parquet files in parallel
    2. Commit the preuploaded files

    Authenticate using `huggingface-cli login` or by passing a token
    with the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    df : DataFrame
        PySpark DataFrame to upload.
    path : str
        Path of the file or directory. Prefix with a protocol like `hf://` to write to Hugging Face.
        It writes Parquet files in the form "part-xxxxx.parquet", or to a single file if `path` ends with ".parquet".

    **kwargs
        Any additional kwargs are passed to pyarrow.parquet.ParquetWriter.

    Returns
    -------
    None

    Examples
    --------
    >>> df = spark.createDataFrame(pd.DataFrame({"foo": range(5), "bar": range(5, 10)}))
    >>> # Save to one file
    >>> write_parquet(df, "hf://datasets/username/dataset/data.parquet")
    >>> # OR save to a directory (possibly in many files)
    >>> write_parquet(df, "hf://datasets/username/dataset")
    """
    filesystem: HfFileSystem = kwargs.pop("filesystem", HfFileSystem(**kwargs.pop("storage_options", {})))
    if path.endswith(".parquet") or path.endswith(".pq"):
        # Single-file output: bring all data into one partition.
        df = df.coalesce(1)
    else:
        # Directory output: one shard per partition, named "part-xxxxx.parquet".
        path += "/part-{shard_idx:05d}.parquet"
    # Step 1: preupload one Parquet file per partition; step 2: gather the preuploaded
    # additions into a single partition and commit them to the repository.
    df.mapInArrow(
        partial(_preupload, path=path, schema=to_arrow_schema(df.schema), filesystem=filesystem, **kwargs),
        from_arrow_schema(pa.schema({"addition": pa.binary()})),
    ).coalesce(1).mapInArrow(
        partial(_commit, path=path, filesystem=filesystem),
        from_arrow_schema(pa.schema({"path": pa.string()})),
    ).collect()
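
# End-to-end sketch (repository names are illustrative; authenticate via
# `huggingface-cli login` or `storage_options={"token": "hf_xxx"}`), assuming a
# SparkSession was already registered with `set_session`:
#
#   df = read_parquet("hf://datasets/username/source/data.parquet", filters=[("foo", ">", 2)])
#   write_parquet(df, "hf://datasets/username/destination")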