File size: 7,618 Bytes
019fb90 88ad24a 019fb90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import math
import pickle
import tempfile
from functools import partial
from typing import Iterator, Optional, Union
import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitOperationAdd, HfFileSystem
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
# Spark session used by read_parquet(); stays None until set_session() is called.
spark = None
def set_session(session):
    """Register `session` as the module-wide Spark session.

    The value is stored in the module-level `spark` global, which
    `read_parquet` uses to build the DataFrame of file paths.
    """
    global spark
    spark = session
def _read(iterator: Iterator[pa.RecordBatch], columns: Optional[list[str]], filters: Optional[Union[list[tuple], list[list[tuple]]]], **kwargs) -> Iterator[pa.RecordBatch]:
    """Read the Parquet files whose paths arrive in `iterator` and yield their record batches.

    Parameters
    ----------
    iterator : Iterator[pa.RecordBatch]
        Batches whose first column holds the file paths to read.
    columns : list of str, optional
        Columns to load; passed to `to_batches(columns=...)`.
    filters : list of tuples or list of lists of tuples, optional
        Row filters in the `pyarrow.parquet` DNF syntax. They are converted with
        `pq.filters_to_expression`; a falsy value means no filtering.
    **kwargs
        Forwarded to `pyarrow.parquet.ParquetDataset`.

    Yields
    ------
    pa.RecordBatch
        The batches produced by the dataset for each group of paths.
    """
    # The filter expression does not depend on the batch, so build it once
    # instead of once per incoming batch.
    filter_expression = pq.filters_to_expression(filters) if filters else None
    for batch in iterator:
        paths = batch[0].to_pylist()
        ds = pq.ParquetDataset(paths, **kwargs)
        # NOTE: relies on the private `_dataset` attribute of ParquetDataset.
        yield from ds._dataset.to_batches(columns=columns, filter=filter_expression)
def read_parquet(
    path: str,
    columns: Optional[list[str]] = None,
    filters: Optional[Union[list[tuple], list[list[tuple]]]] = None,
    **kwargs,
) -> DataFrame:
    """
    Loads Parquet files from Hugging Face using PyArrow, returning a PySpark `DataFrame`.

    It reads Parquet files in a distributed manner.

    Access private or gated repositories using `huggingface-cli login` or passing a token
    using the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    path : str
        Path to the file. Prefix with a protocol like `hf://` to read from Hugging Face.
        You can read from multiple files if you pass a globstring.
    columns : list, default None
        If not None, only these columns will be read from the file.
    filters : List[Tuple] or List[List[Tuple]], default None
        To filter out data.
        Filter syntax: [[(column, op, val), ...],...]
        where op is [==, =, >, >=, <, <=, !=, in, not in]
        The innermost tuples are transposed into a set of filters applied
        through an `AND` operation.
        The outer list combines these sets of filters through an `OR`
        operation.
        A single list of tuples can also be used, meaning that no `OR`
        operation between set of filters is to be conducted.
    **kwargs
        `filesystem` (an `HfFileSystem` to use) and `storage_options` (arguments
        used to build an `HfFileSystem` when no `filesystem` is given) are consumed
        here. Any additional kwargs are passed to pyarrow.parquet.ParquetDataset.

    Returns
    -------
    DataFrame
        DataFrame based on parquet file.

    Raises
    ------
    FileNotFoundError
        If `path` matches no file.
    RuntimeError
        If no Spark session was registered with `set_session`.

    Examples
    --------
    >>> import pandas as pd
    >>> path = "hf://datasets/username/dataset/data.parquet"
    >>> pd.DataFrame({"foo": range(5), "bar": range(5, 10)}).to_parquet(path)
    >>> read_parquet(path).show()
    +---+---+
    |foo|bar|
    +---+---+
    |  0|  5|
    |  1|  6|
    |  2|  7|
    |  3|  8|
    |  4|  9|
    +---+---+
    >>> read_parquet(path, columns=["bar"]).show()
    +---+
    |bar|
    +---+
    |  5|
    |  6|
    |  7|
    |  8|
    |  9|
    +---+
    >>> sel = [("foo", ">", 2)]
    >>> read_parquet(path, filters=sel).show()
    +---+---+
    |foo|bar|
    +---+---+
    |  3|  8|
    |  4|  9|
    +---+---+
    """
    # Always consume `storage_options` so it is never forwarded to ParquetDataset,
    # even when an explicit `filesystem` is supplied.
    storage_options = kwargs.pop("storage_options", None) or {}
    filesystem: Optional[HfFileSystem] = kwargs.pop("filesystem", None)
    if filesystem is None:
        filesystem = HfFileSystem(**storage_options)

    paths = filesystem.glob(path)
    if not paths:
        raise FileNotFoundError(f"Counldn't find any file at {path}")
    if spark is None:
        raise RuntimeError("No Spark session configured: call set_session(...) before read_parquet().")

    # One partition per file so the files are read in parallel.
    rdd = spark.sparkContext.parallelize([{"path": file_path} for file_path in paths], len(paths))
    df = spark.createDataFrame(rdd)

    # The schema of the first file is taken as the schema of the whole dataset.
    with filesystem.open(paths[0]) as first_file:
        arrow_schema = pq.read_schema(first_file)
    schema = pa.schema(
        [field for field in arrow_schema if (columns is None or field.name in columns)],
        metadata=arrow_schema.metadata,
    )
    return df.mapInArrow(
        partial(_read, columns=columns, filters=filters, filesystem=filesystem, schema=arrow_schema, **kwargs),
        from_arrow_schema(schema),
    )
def _preupload(iterator: Iterator[pa.RecordBatch], path: str, schema: pa.Schema, filesystem: HfFileSystem, row_group_size: Optional[int] = None, **kwargs) -> Iterator[pa.RecordBatch]:
    """Write the incoming batches to a temporary Parquet file and pre-upload it.

    Parameters
    ----------
    iterator : Iterator[pa.RecordBatch]
        Batches to write into a single Parquet file.
    path : str
        Destination path; only used to resolve the repository id, type and revision.
    schema : pa.Schema
        Schema of the Parquet file to write.
    filesystem : HfFileSystem
        Filesystem whose API client performs the pre-upload.
    row_group_size : int, optional
        Passed to `ParquetWriter.write_batch`.
    **kwargs
        Forwarded to `pyarrow.parquet.ParquetWriter`.

    Yields
    ------
    pa.RecordBatch
        A single one-row batch whose `addition` column holds the pickled
        `CommitOperationAdd` describing the pre-uploaded file.
    """
    resolved_path = filesystem.resolve_path(path)
    with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
        # The writer must be closed before uploading so the Parquet footer is written.
        with pq.ParquetWriter(temp_file.name, schema=schema, **kwargs) as writer:
            for batch in iterator:
                writer.write_batch(batch, row_group_size=row_group_size)
        # `path_in_repo` is only a placeholder here: `_commit` assigns the final one.
        # NOTE(review): the arguments of this call were missing in the reviewed source
        # and were reconstructed from context - confirm against the upstream file.
        addition = CommitOperationAdd(path_in_repo=temp_file.name, path_or_fileobj=temp_file.name)
        # Upload while the temporary file still exists.
        filesystem._api.preupload_lfs_files(repo_id=resolved_path.repo_id, additions=[addition], repo_type=resolved_path.repo_type, revision=resolved_path.revision)
    yield pa.record_batch({"addition": [pickle.dumps(addition)]}, schema=pa.schema({"addition": pa.binary()}))
def _commit(iterator: Iterator[pa.RecordBatch], path: str, filesystem: HfFileSystem, max_operations_per_commit=50) -> Iterator[pa.RecordBatch]:
    """Commit the pre-uploaded files described by the incoming batches.

    Each incoming row carries a pickled `CommitOperationAdd` in its `addition`
    column. Every addition gets its final `path_in_repo` by substituting its
    index for the literal `{shard_idx:05d}` placeholder of the resolved path.
    The additions are then committed in groups of at most
    `max_operations_per_commit`.

    Yields one record batch per commit, whose `path` column lists the
    repository paths included in that commit.
    """
    resolved_path = filesystem.resolve_path(path)

    pickled_additions = pa.Table.from_batches(iterator, schema=pa.schema({"addition": pa.binary()}))[0].to_pylist()
    # The pickled payloads are the ones produced by `_preupload` in this module.
    additions: list[CommitOperationAdd] = [pickle.loads(blob) for blob in pickled_additions]
    num_commits = math.ceil(len(additions) / max_operations_per_commit)

    for shard_idx, addition in enumerate(additions):
        shard_label = f"{shard_idx:05d}"
        addition.path_in_repo = resolved_path.path_in_repo.replace("{shard_idx:05d}", shard_label)

    for i in range(num_commits):
        start = i * max_operations_per_commit
        stop = (i + 1) * max_operations_per_commit
        operations = additions[start:stop]
        # Only mention the part number when the upload spans several commits.
        suffix = f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else ""
        commit_message = "Upload using PySpark" + suffix
        filesystem._api.create_commit(
            repo_id=resolved_path.repo_id,
            repo_type=resolved_path.repo_type,
            revision=resolved_path.revision,
            operations=operations,
            commit_message=commit_message,
        )
        committed_paths = [operation.path_in_repo for operation in operations]
        yield pa.record_batch({"path": committed_paths}, schema=pa.schema({"path": pa.string()}))
def write_parquet(df: DataFrame, path: str, **kwargs) -> None:
    """
    Write Parquet files to Hugging Face using PyArrow.

    It uploads Parquet files in a distributed manner in two steps:

    1. Preupload the Parquet files in parallel in a distributed manner
    2. Commit the preuploaded files

    Authenticate using `huggingface-cli login` or passing a token
    using the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    df : DataFrame
        The PySpark DataFrame to upload.
    path : str
        Path of the file or directory. Prefix with a protocol like `hf://` to write to Hugging Face.
        It writes Parquet files in the form "part-xxxxx.parquet", or to a single file if `path` ends with ".parquet".
    **kwargs
        `filesystem` (an `HfFileSystem` to use) and `storage_options` (arguments
        used to build an `HfFileSystem` when no `filesystem` is given) are consumed
        here. Any additional kwargs are passed to pyarrow.parquet.ParquetWriter.

    Returns
    -------
    None

    Examples
    --------
    >>> import pandas as pd
    >>> df = spark.createDataFrame(pd.DataFrame({"foo": range(5), "bar": range(5, 10)}))
    >>> # Save to one file
    >>> write_parquet(df, "hf://datasets/username/dataset/data.parquet")
    >>> # OR save to a directory (possibly in many files)
    >>> write_parquet(df, "hf://datasets/username/dataset")
    """
    # Build an HfFileSystem only when the caller did not supply one.
    storage_options = kwargs.pop("storage_options", None) or {}
    filesystem: Optional[HfFileSystem] = kwargs.pop("filesystem", None)
    if filesystem is None:
        filesystem = HfFileSystem(**storage_options)

    if path.endswith((".parquet", ".pq")):
        # Single-file target: gather everything into one partition, hence one file.
        df = df.coalesce(1)
    else:
        # Directory target: `_commit` replaces this literal placeholder with each shard index.
        path += "/part-{shard_idx:05d}.parquet"

    preuploaded = df.mapInArrow(
        partial(_preupload, path=path, schema=to_arrow_schema(df.schema), filesystem=filesystem, **kwargs),
        from_arrow_schema(pa.schema({"addition": pa.binary()})),
    )
    # All additions must reach a single task so shard indices are assigned consistently.
    # NOTE(review): the call chaining the two stages was missing in the reviewed source;
    # `.repartition(1)` was reconstructed from context - confirm against the upstream file.
    preuploaded.repartition(1).mapInArrow(
        partial(_commit, path=path, filesystem=filesystem),
        from_arrow_schema(pa.schema({"path": pa.string()})),
    ).collect()